core/tests/components/stt/common.py

161 lines
4.6 KiB
Python

"""Provide common test tools for STT."""
from __future__ import annotations
from collections.abc import AsyncIterable, Callable, Coroutine
from pathlib import Path
from typing import Any
from homeassistant.components.stt import (
AudioBitRates,
AudioChannels,
AudioCodecs,
AudioFormats,
AudioSampleRates,
Provider,
SpeechMetadata,
SpeechResult,
SpeechResultState,
SpeechToTextEntity,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from tests.common import MockPlatform, mock_platform
# Domain name for the mock STT integration in these tests; MockSTTProvider
# uses it as its url_path.
TEST_DOMAIN = "test"
class BaseProvider:
    """Mock STT provider behavior shared by the provider and entity mocks.

    Records every call to ``async_process_audio_stream`` and every non-empty
    audio chunk it reads, so tests can assert on them afterwards.
    """

    # Set to True (on the class or an instance) to make
    # async_process_audio_stream return an ERROR result.
    fail_process_audio = False

    def __init__(
        self, *, supported_languages: list[str] | None = None, text: str = "test_result"
    ) -> None:
        """Init test provider.

        Args:
            supported_languages: Languages reported as supported. ``None``
                selects the default ``["de", "de-CH", "en"]``; an explicit
                empty list is kept as-is.
            text: Transcript returned when processing succeeds.
        """
        # Compare against None rather than relying on truthiness, so an
        # explicit [] can model a provider that supports no languages.
        if supported_languages is None:
            supported_languages = ["de", "de-CH", "en"]
        self._supported_languages = supported_languages
        # One (metadata, stream) pair per async_process_audio_stream call.
        self.calls: list[tuple[SpeechMetadata, AsyncIterable[bytes]]] = []
        # Non-empty audio chunks read from the streams, in arrival order.
        self.received: list[bytes] = []
        self.text = text

    @property
    def supported_languages(self) -> list[str]:
        """Return a list of supported languages."""
        return self._supported_languages

    @property
    def supported_formats(self) -> list[AudioFormats]:
        """Return a list of supported formats."""
        return [AudioFormats.WAV, AudioFormats.OGG]

    @property
    def supported_codecs(self) -> list[AudioCodecs]:
        """Return a list of supported codecs."""
        return [AudioCodecs.PCM, AudioCodecs.OPUS]

    @property
    def supported_bit_rates(self) -> list[AudioBitRates]:
        """Return a list of supported bitrates."""
        return [AudioBitRates.BITRATE_16]

    @property
    def supported_sample_rates(self) -> list[AudioSampleRates]:
        """Return a list of supported samplerates."""
        return [AudioSampleRates.SAMPLERATE_16000]

    @property
    def supported_channels(self) -> list[AudioChannels]:
        """Return a list of supported channels."""
        return [AudioChannels.CHANNEL_MONO]

    async def async_process_audio_stream(
        self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> SpeechResult:
        """Process an audio stream.

        Records the call, then reads chunks from *stream* until it is
        exhausted or an empty chunk arrives. Non-empty chunks are appended
        to ``self.received``.

        Returns:
            ``SpeechResult(None, ERROR)`` when ``fail_process_audio`` is set,
            otherwise ``SpeechResult(self.text, SUCCESS)``.
        """
        self.calls.append((metadata, stream))
        async for data in stream:
            # An empty chunk is treated as the end of the audio.
            if not data:
                break
            self.received.append(data)
        # The stream is drained before the failure check, so chunks are
        # recorded even when a failure result is returned.
        if self.fail_process_audio:
            return SpeechResult(None, SpeechResultState.ERROR)
        return SpeechResult(self.text, SpeechResultState.SUCCESS)
class MockSTTProvider(BaseProvider, Provider):
    """Mock STT provider: BaseProvider's recording behavior on the Provider API."""

    # Served under the shared test domain name.
    url_path = TEST_DOMAIN
class MockSTTProviderEntity(BaseProvider, SpeechToTextEntity):
    """Mock STT entity: BaseProvider's recording behavior as a SpeechToTextEntity."""

    _attr_name = "test"
    url_path = "stt.test"
class MockSTTPlatform(MockPlatform):
    """Mock stt platform carrying optional get_engine/async_get_engine hooks."""

    def __init__(
        self,
        async_get_engine: Callable[
            [HomeAssistant, ConfigType, DiscoveryInfoType | None],
            Coroutine[Any, Any, Provider | None],
        ]
        | None = None,
        get_engine: Callable[
            [HomeAssistant, ConfigType, DiscoveryInfoType | None], Provider | None
        ]
        | None = None,
    ) -> None:
        """Initialize the platform and attach whichever engine hooks were given.

        A hook that is omitted (or falsy) is not set on the instance.
        """
        super().__init__()
        engine_hooks = (
            ("get_engine", get_engine),
            ("async_get_engine", async_get_engine),
        )
        for attr_name, hook in engine_hooks:
            if hook:
                setattr(self, attr_name, hook)
def mock_stt_platform(
    hass: HomeAssistant,
    tmp_path: Path,
    integration: str = "stt",
    async_get_engine: Callable[
        [HomeAssistant, ConfigType, DiscoveryInfoType | None],
        Coroutine[Any, Any, Provider | None],
    ]
    | None = None,
    get_engine: Callable[
        [HomeAssistant, ConfigType, DiscoveryInfoType | None], Provider | None
    ]
    | None = None,
) -> MockSTTPlatform:
    """Register a MockSTTPlatform as the ``stt`` platform of *integration*.

    Args:
        hass: Home Assistant instance the platform is registered with.
        tmp_path: Not used by this helper.
        integration: Integration whose ``<integration>.stt`` platform is mocked.
        async_get_engine: Optional coroutine factory attached to the platform.
        get_engine: Optional sync factory attached to the platform.

    Returns:
        The registered MockSTTPlatform instance.
    """
    # Pass by keyword so the two factories cannot be swapped if the
    # MockSTTPlatform signature is ever reordered.
    loaded_platform = MockSTTPlatform(
        async_get_engine=async_get_engine, get_engine=get_engine
    )
    mock_platform(hass, f"{integration}.stt", loaded_platform)
    return loaded_platform
def mock_stt_entity_platform(
    hass: HomeAssistant,
    tmp_path: Path,
    integration: str,
    async_setup_entry: Callable[
        [HomeAssistant, ConfigEntry, AddEntitiesCallback],
        Coroutine[Any, Any, None],
    ]
    | None = None,
) -> MockPlatform:
    """Register a MockPlatform wired to *async_setup_entry* as ``<integration>.stt``.

    *tmp_path* is not used by this helper. Returns the registered platform.
    """
    platform = MockPlatform(async_setup_entry=async_setup_entry)
    platform_path = f"{integration}.stt"
    mock_platform(hass, platform_path, platform)
    return platform