mirror of https://github.com/home-assistant/core
266 lines
7.7 KiB
Python
266 lines
7.7 KiB
Python
"""Provide common tests tools for tts."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Generator
|
|
from http import HTTPStatus
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
import voluptuous as vol
|
|
|
|
from homeassistant.components import media_source
|
|
from homeassistant.components.tts import (
|
|
CONF_LANG,
|
|
DOMAIN as TTS_DOMAIN,
|
|
PLATFORM_SCHEMA as TTS_PLATFORM_SCHEMA,
|
|
Provider,
|
|
TextToSpeechEntity,
|
|
TtsAudioType,
|
|
Voice,
|
|
_get_cache_files,
|
|
)
|
|
from homeassistant.config_entries import ConfigEntry
|
|
from homeassistant.core import HomeAssistant, callback
|
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
|
from homeassistant.setup import async_setup_component
|
|
|
|
from tests.common import (
|
|
MockConfigEntry,
|
|
MockModule,
|
|
MockPlatform,
|
|
mock_integration,
|
|
mock_platform,
|
|
)
|
|
from tests.typing import ClientSessionGenerator
|
|
|
|
DEFAULT_LANG = "en_US"
|
|
SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"]
|
|
TEST_DOMAIN = "test"
|
|
|
|
|
|
def mock_tts_get_cache_files_fixture_helper() -> Generator[MagicMock]:
    """Patch the TTS cache-file listing so it reports an empty cache.

    Yields the mock that replaces
    ``homeassistant.components.tts._get_cache_files`` (returning ``{}``);
    the patch is removed when the generator is closed.
    """
    cache_listing_patch = patch(
        "homeassistant.components.tts._get_cache_files", return_value={}
    )
    with cache_listing_patch as cache_files_mock:
        yield cache_files_mock
|
|
|
|
|
|
def mock_tts_init_cache_dir_fixture_helper(
    init_tts_cache_dir_side_effect: Any,
) -> Generator[MagicMock]:
    """Patch TTS cache-dir initialisation so no real directory is touched.

    Args:
        init_tts_cache_dir_side_effect: Passed straight through as the
            ``side_effect`` of the mock replacing
            ``homeassistant.components.tts._init_tts_cache_dir``.

    Yields:
        The mock standing in for ``_init_tts_cache_dir``.
    """
    patch_target = "homeassistant.components.tts._init_tts_cache_dir"
    with patch(
        patch_target, side_effect=init_tts_cache_dir_side_effect
    ) as init_dir_mock:
        yield init_dir_mock
|
|
|
|
|
|
def init_tts_cache_dir_side_effect_fixture_helper() -> Any:
    """Provide the default side effect for the cache-dir init mock.

    ``None`` means the mock has no side effect configured, so calls fall back
    to the mock's ``return_value``.
    """
    default_side_effect = None
    return default_side_effect
|
|
|
|
|
|
def mock_tts_cache_dir_fixture_helper(
    tmp_path: Path,
    mock_tts_init_cache_dir: MagicMock,
    mock_tts_get_cache_files: MagicMock,
    request: pytest.FixtureRequest,
) -> Generator[Path]:
    """Back the TTS cache with a real, initially empty temporary directory.

    Points the cache-dir init mock at ``tmp_path`` and re-enables the real
    cache-file listing, then yields ``tmp_path``. After the test, if the node
    carries a ``rep_call`` report that did not pass, the directory contents
    are printed and the fixture fails.

    NOTE(review): ``rep_call`` is presumably attached to the node by a conftest
    report hook not visible here; when it is absent nothing is printed.
    """
    mock_tts_init_cache_dir.return_value = str(tmp_path)
    # The directory is real now, so let the genuine listing function run.
    mock_tts_get_cache_files.side_effect = _get_cache_files

    yield tmp_path

    no_report = object()
    call_report = getattr(request.node, "rep_call", no_report)
    if call_report is not no_report and not call_report.passed:
        # Dump what ended up in the cache dir to help debug the failure.
        print("Content of dir for", request.node.nodeid)  # noqa: T201
        for cached_file in tmp_path.iterdir():
            print(cached_file.relative_to(tmp_path))  # noqa: T201

        # Fail the teardown so pytest reports the output printed above.
        pytest.fail("Test failed, see log for details")
|
|
|
|
|
|
def tts_mutagen_mock_fixture_helper() -> Generator[MagicMock]:
    """Stub out audio tag writing for the duration of a test.

    The patched ``SpeechManager.write_tags`` does no tagging and simply
    returns the second positional argument it was called with.
    """

    def _return_second_arg(*call_args):
        return call_args[1]

    with patch(
        "homeassistant.components.tts.SpeechManager.write_tags",
        side_effect=_return_second_arg,
    ) as write_tags_mock:
        yield write_tags_mock
|
|
|
|
|
|
async def get_media_source_url(hass: HomeAssistant, media_content_id: str) -> str:
    """Resolve a media-source content ID to its URL.

    Sets up the ``media_source`` integration first (asserting success) when it
    is not already among the loaded components.
    """
    media_source_loaded = media_source.DOMAIN in hass.config.components
    if not media_source_loaded:
        assert await async_setup_component(hass, media_source.DOMAIN, {})

    play_media = await media_source.async_resolve_media(
        hass, media_content_id, None
    )
    return play_media.url
|
|
|
|
|
|
async def retrieve_media(
    hass: HomeAssistant, hass_client: ClientSessionGenerator, media_content_id: str
) -> HTTPStatus:
    """Fetch a media-source item over HTTP and return the response status.

    The content ID is resolved to a URL with ``get_media_source_url``, pending
    work is awaited, and the URL is requested with a client obtained from
    ``hass_client``.

    Args:
        hass: The Home Assistant instance under test.
        hass_client: Factory returning an HTTP test client.
        media_content_id: Media-source content ID to resolve and fetch.

    Returns:
        The status of the HTTP GET response for the resolved URL.
    """
    url = await get_media_source_url(hass, media_content_id)

    # Ensure media has been generated by requesting it: first let pending
    # tasks finish, then perform the actual GET.
    await hass.async_block_till_done()
    client = await hass_client()
    response = await client.get(url)

    return response.status
|
|
|
|
|
|
class BaseProvider:
    """Test speech API provider.

    Shared implementation mixed into ``MockTTSProvider`` and ``MockTTSEntity``:
    fixed supported languages/options, a small voice list for "en-US", and
    empty MP3 audio for every request.
    """

    def __init__(self, lang: str) -> None:
        """Initialize test provider.

        Args:
            lang: Language code reported as ``default_language``.
        """
        self._lang = lang
        # Copy so in-place changes on one provider cannot leak into the
        # module-level constant shared by every test.
        self._supported_languages = list(SUPPORT_LANGUAGES)
        self._supported_options = ["voice", "age"]

    @property
    def default_language(self) -> str:
        """Return the default language."""
        return self._lang

    @property
    def supported_languages(self) -> list[str]:
        """Return list of supported languages."""
        return self._supported_languages

    @callback
    def async_get_supported_voices(self, language: str) -> list[Voice] | None:
        """Return the list of supported voices for ``language``.

        Returns two fixed voices for "en-US" and ``None`` for anything else.
        """
        # NOTE(review): only the hyphenated tag "en-US" matches here, while
        # SUPPORT_LANGUAGES uses underscores ("en_US"), so "en_US" yields None.
        # Kept as-is because tests may rely on it; confirm intent.
        if language == "en-US":
            return [
                Voice("james_earl_jones", "James Earl Jones"),
                Voice("fran_drescher", "Fran Drescher"),
            ]
        return None

    @property
    def supported_options(self) -> list[str]:
        """Return list of supported options like voice, emotions."""
        return self._supported_options

    def get_tts_audio(
        self, message: str, language: str, options: dict[str, Any]
    ) -> TtsAudioType:
        """Load TTS data.

        Ignores all arguments and returns an empty MP3 payload.
        """
        return ("mp3", b"")
|
|
|
|
|
|
class MockTTSProvider(BaseProvider, Provider):
    """Speech provider handed out by the mock TTS platform."""

    def __init__(self, lang: str) -> None:
        """Create the provider with ``lang`` as default language, named "Test"."""
        super().__init__(lang)
        # Display name of this provider.
        self.name = "Test"
|
|
|
|
|
|
class MockTTSEntity(BaseProvider, TextToSpeechEntity):
    """Text-to-speech entity used by config-entry based tests."""

    # Entity name; all speech behavior comes from BaseProvider.
    _attr_name = "Test"
|
|
|
|
|
|
class MockTTS(MockPlatform):
    """Mock TTS platform that serves a pre-built provider."""

    # Platform config: optional language, defaulting to DEFAULT_LANG and
    # restricted to the test languages.
    PLATFORM_SCHEMA = TTS_PLATFORM_SCHEMA.extend(
        {vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In(SUPPORT_LANGUAGES)}
    )

    def __init__(self, provider: MockTTSProvider, **kwargs: Any) -> None:
        """Keep ``provider``; remaining keyword arguments go to MockPlatform."""
        super().__init__(**kwargs)
        self._provider = provider

    async def async_get_engine(
        self,
        hass: HomeAssistant,
        config: ConfigType,
        discovery_info: DiscoveryInfoType | None = None,
    ) -> Provider | None:
        """Return the stored provider regardless of config or discovery info."""
        return self._provider
|
|
|
|
|
|
async def mock_setup(
    hass: HomeAssistant,
    mock_provider: MockTTSProvider,
) -> None:
    """Register a mock TTS platform for ``mock_provider`` and set up tts with it.

    The return value of ``async_setup_component`` is not checked here.
    """
    mock_integration(hass, MockModule(domain=TEST_DOMAIN))
    platform = MockTTS(mock_provider)
    mock_platform(hass, f"{TEST_DOMAIN}.{TTS_DOMAIN}", platform)

    tts_config = {TTS_DOMAIN: {"platform": TEST_DOMAIN}}
    await async_setup_component(hass, TTS_DOMAIN, tts_config)
    await hass.async_block_till_done()
|
|
|
|
|
|
async def mock_config_entry_setup(
    hass: HomeAssistant,
    tts_entity: MockTTSEntity,
    test_domain: str = TEST_DOMAIN,
) -> MockConfigEntry:
    """Set up ``tts_entity`` through a mock integration's config entry.

    Registers a mock integration under ``test_domain`` that forwards its entry
    to the tts platform, registers a platform that adds ``tts_entity``, then
    creates, adds and sets up a config entry (asserting success).

    Returns:
        The config entry that was set up.
    """

    async def _forward_entry_setup(
        hass: HomeAssistant, config_entry: ConfigEntry
    ) -> bool:
        """Forward the config entry to the tts platform."""
        await hass.config_entries.async_forward_entry_setups(
            config_entry, [TTS_DOMAIN]
        )
        return True

    async def _forward_entry_unload(
        hass: HomeAssistant, config_entry: ConfigEntry
    ) -> bool:
        """Unload the tts platform for the config entry."""
        await hass.config_entries.async_forward_entry_unload(config_entry, TTS_DOMAIN)
        return True

    async def _add_tts_entity(
        hass: HomeAssistant,
        config_entry: ConfigEntry,
        async_add_entities: AddEntitiesCallback,
    ) -> None:
        """Add the supplied tts entity for the config entry."""
        async_add_entities([tts_entity])

    integration_module = MockModule(
        test_domain,
        async_setup_entry=_forward_entry_setup,
        async_unload_entry=_forward_entry_unload,
    )
    mock_integration(hass, integration_module)

    tts_platform = MockPlatform(async_setup_entry=_add_tts_entity)
    mock_platform(hass, f"{test_domain}.{TTS_DOMAIN}", tts_platform)

    entry = MockConfigEntry(domain=test_domain)
    entry.add_to_hass(hass)
    assert await hass.config_entries.async_setup(entry.entry_id)
    await hass.async_block_till_done()

    return entry
|