# Mirror of https://github.com/home-assistant/core (161 lines, 4.6 KiB, Python)
"""Provide common test tools for STT."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import AsyncIterable, Callable, Coroutine
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from homeassistant.components.stt import (
|
|
AudioBitRates,
|
|
AudioChannels,
|
|
AudioCodecs,
|
|
AudioFormats,
|
|
AudioSampleRates,
|
|
Provider,
|
|
SpeechMetadata,
|
|
SpeechResult,
|
|
SpeechResultState,
|
|
SpeechToTextEntity,
|
|
)
|
|
from homeassistant.config_entries import ConfigEntry
|
|
from homeassistant.core import HomeAssistant
|
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
|
|
|
from tests.common import MockPlatform, mock_platform
|
|
|
|
# Domain name used by the mocks in this module; MockSTTProvider uses it as its url_path.
TEST_DOMAIN = "test"
|
|
|
|
|
|
class BaseProvider:
    """Shared mock implementation of a speech-to-text provider.

    Records every processing call and every non-empty audio chunk it is fed,
    and answers with a configurable, fixed transcription.
    """

    # When truthy, async_process_audio_stream reports an ERROR result instead
    # of returning the configured text.
    fail_process_audio = False

    def __init__(
        self, *, supported_languages: list[str] | None = None, text: str = "test_result"
    ) -> None:
        """Set up the mock.

        supported_languages: languages to advertise; a missing or empty value
            falls back to ``["de", "de-CH", "en"]``.
        text: transcription handed back on a successful run.
        """
        if not supported_languages:
            supported_languages = ["de", "de-CH", "en"]
        self._supported_languages = supported_languages
        self.text = text
        # One (metadata, stream) pair per async_process_audio_stream call.
        self.calls: list[tuple[SpeechMetadata, AsyncIterable[bytes]]] = []
        # Audio chunks consumed from the streams, in arrival order.
        self.received: list[bytes] = []

    @property
    def supported_languages(self) -> list[str]:
        """Languages this mock advertises."""
        return self._supported_languages

    @property
    def supported_formats(self) -> list[AudioFormats]:
        """Audio container formats this mock advertises (WAV and OGG)."""
        return [AudioFormats.WAV, AudioFormats.OGG]

    @property
    def supported_codecs(self) -> list[AudioCodecs]:
        """Audio codecs this mock advertises (PCM and OPUS)."""
        return [AudioCodecs.PCM, AudioCodecs.OPUS]

    @property
    def supported_bit_rates(self) -> list[AudioBitRates]:
        """Bit rates this mock advertises (BITRATE_16 only)."""
        return [AudioBitRates.BITRATE_16]

    @property
    def supported_sample_rates(self) -> list[AudioSampleRates]:
        """Sample rates this mock advertises (SAMPLERATE_16000 only)."""
        return [AudioSampleRates.SAMPLERATE_16000]

    @property
    def supported_channels(self) -> list[AudioChannels]:
        """Channel layouts this mock advertises (CHANNEL_MONO only)."""
        return [AudioChannels.CHANNEL_MONO]

    async def async_process_audio_stream(
        self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> SpeechResult:
        """Consume an audio stream and return the canned result.

        The call is logged in ``calls`` and each chunk is stored in
        ``received`` until the stream ends or an empty chunk is seen.
        Returns an ERROR result with no text when ``fail_process_audio`` is
        truthy, otherwise a SUCCESS result carrying ``text``.
        """
        self.calls.append((metadata, stream))

        async for chunk in stream:
            if chunk:
                self.received.append(chunk)
            else:
                # An empty chunk stops consumption of the stream.
                break

        if not self.fail_process_audio:
            return SpeechResult(self.text, SpeechResultState.SUCCESS)
        return SpeechResult(None, SpeechResultState.ERROR)
|
|
|
|
|
|
class MockSTTProvider(BaseProvider, Provider):
    """Mock STT provider built from BaseProvider and the stt ``Provider`` class."""

    # Reuses the module-level test domain as its URL path.
    url_path = TEST_DOMAIN
|
|
|
|
|
|
class MockSTTProviderEntity(BaseProvider, SpeechToTextEntity):
    """Mock STT entity built from BaseProvider and ``SpeechToTextEntity``."""

    # Fixed URL path and entity name used by this mock.
    url_path = "stt.test"
    _attr_name = "test"
|
|
|
|
|
|
class MockSTTPlatform(MockPlatform):
    """Mock platform that can expose stt engine factory hooks."""

    def __init__(
        self,
        async_get_engine: Callable[
            [HomeAssistant, ConfigType, DiscoveryInfoType | None],
            Coroutine[Any, Any, Provider | None],
        ]
        | None = None,
        get_engine: Callable[
            [HomeAssistant, ConfigType, DiscoveryInfoType | None], Provider | None
        ]
        | None = None,
    ) -> None:
        """Initialize the platform and attach whichever hooks were supplied.

        async_get_engine: optional coroutine factory stored as
            ``self.async_get_engine``.
        get_engine: optional synchronous factory stored as
            ``self.get_engine``.

        A hook that is omitted (or falsy) is not assigned on the instance.
        """
        super().__init__()
        hooks = (
            ("get_engine", get_engine),
            ("async_get_engine", async_get_engine),
        )
        for attr_name, hook in hooks:
            if hook:
                setattr(self, attr_name, hook)
|
|
|
|
|
|
def mock_stt_platform(
    hass: HomeAssistant,
    tmp_path: Path,
    integration: str = "stt",
    async_get_engine: Callable[
        [HomeAssistant, ConfigType, DiscoveryInfoType | None],
        Coroutine[Any, Any, Provider | None],
    ]
    | None = None,
    get_engine: Callable[
        [HomeAssistant, ConfigType, DiscoveryInfoType | None], Provider | None
    ]
    | None = None,
) -> MockSTTPlatform:
    """Create a MockSTTPlatform and register it as ``<integration>.stt``.

    hass: Home Assistant instance passed to ``mock_platform``.
    tmp_path: accepted as part of the signature; not used by this function.
    integration: name used to build the ``"<integration>.stt"`` platform path.
    async_get_engine: optional coroutine engine factory forwarded to the
        platform.
    get_engine: optional synchronous engine factory forwarded to the
        platform.

    Returns the platform object that was registered.
    """
    loaded_platform = MockSTTPlatform(async_get_engine, get_engine)
    mock_platform(hass, f"{integration}.stt", loaded_platform)

    return loaded_platform
|
|
|
|
|
|
def mock_stt_entity_platform(
    hass: HomeAssistant,
    tmp_path: Path,
    integration: str,
    async_setup_entry: Callable[
        [HomeAssistant, ConfigEntry, AddEntitiesCallback],
        Coroutine[Any, Any, None],
    ]
    | None = None,
) -> MockPlatform:
    """Create a MockPlatform for config-entry setup and register it as ``<integration>.stt``.

    hass: Home Assistant instance passed to ``mock_platform``.
    tmp_path: accepted as part of the signature; not used by this function.
    integration: name used to build the ``"<integration>.stt"`` platform path.
    async_setup_entry: optional coroutine handed to the MockPlatform as its
        ``async_setup_entry``.

    Returns the platform object that was registered.
    """
    platform_path = f"{integration}.stt"
    entity_platform = MockPlatform(async_setup_entry=async_setup_entry)
    mock_platform(hass, platform_path, entity_platform)
    return entity_platform
|