core/tests/components/tts/common.py

266 lines
7.7 KiB
Python

"""Provide common tests tools for tts."""
from __future__ import annotations
from collections.abc import Generator
from http import HTTPStatus
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
import voluptuous as vol
from homeassistant.components import media_source
from homeassistant.components.tts import (
CONF_LANG,
DOMAIN as TTS_DOMAIN,
PLATFORM_SCHEMA as TTS_PLATFORM_SCHEMA,
Provider,
TextToSpeechEntity,
TtsAudioType,
Voice,
_get_cache_files,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from homeassistant.setup import async_setup_component
from tests.common import (
MockConfigEntry,
MockModule,
MockPlatform,
mock_integration,
mock_platform,
)
from tests.typing import ClientSessionGenerator
# Default for the optional language option in MockTTS.PLATFORM_SCHEMA.
DEFAULT_LANG = "en_US"
# Languages reported by BaseProvider.supported_languages and accepted by
# MockTTS.PLATFORM_SCHEMA.
SUPPORT_LANGUAGES = ["de_CH", "de_DE", "en_GB", "en_US"]
# Domain of the mock integration registered by mock_setup (and the default
# domain used by mock_config_entry_setup).
TEST_DOMAIN = "test"
def mock_tts_get_cache_files_fixture_helper() -> Generator[MagicMock]:
    """Patch the tts component's ``_get_cache_files`` to return an empty dict.

    Yields:
        The ``MagicMock`` standing in for ``_get_cache_files``. The patch is
        removed when the generator is resumed past the ``yield`` or closed.
    """
    target = "homeassistant.components.tts._get_cache_files"
    with patch(target, return_value={}) as cache_files_mock:
        yield cache_files_mock
def mock_tts_init_cache_dir_fixture_helper(
    init_tts_cache_dir_side_effect: Any,
) -> Generator[MagicMock]:
    """Patch the tts component's ``_init_tts_cache_dir``.

    Args:
        init_tts_cache_dir_side_effect: Passed through unchanged as the
            mock's ``side_effect``.

    Yields:
        The ``MagicMock`` replacing ``_init_tts_cache_dir``; the patch is
        removed when the generator finishes or is closed.
    """
    patcher = patch(
        "homeassistant.components.tts._init_tts_cache_dir",
        side_effect=init_tts_cache_dir_side_effect,
    )
    with patcher as cache_dir_mock:
        yield cache_dir_mock
def init_tts_cache_dir_side_effect_fixture_helper() -> Any:
    """Provide the default side effect for the mocked cache-dir initializer.

    Returns:
        ``None``, i.e. no side effect: the mock then simply returns its
        configured ``return_value``.
    """
    no_side_effect = None
    return no_side_effect
def mock_tts_cache_dir_fixture_helper(
    tmp_path: Path,
    mock_tts_init_cache_dir: MagicMock,
    mock_tts_get_cache_files: MagicMock,
    request: pytest.FixtureRequest,
) -> Generator[Path]:
    """Back the mocked TTS cache dir with the real, empty ``tmp_path``.

    Yields:
        ``tmp_path``, which the mocked cache-dir initializer now returns.

    After the test, if a call-phase report is attached to the test node and
    it did not pass, the directory contents are printed and the fixture
    fails so that output is shown.
    """
    # The mocked initializer hands back the temporary directory as a string.
    mock_tts_init_cache_dir.return_value = str(tmp_path)
    # Restore original get cache files behavior, we're working with a real dir.
    mock_tts_get_cache_files.side_effect = _get_cache_files

    yield tmp_path

    node = request.node
    # NOTE(review): ``rep_call`` is not set in this module; presumably a
    # conftest hook attaches the call-phase report to the node — confirm.
    test_failed = hasattr(node, "rep_call") and not node.rep_call.passed
    if not test_failed:
        return

    # Print contents of dir if failed
    print("Content of dir for", node.nodeid)  # noqa: T201
    for entry in tmp_path.iterdir():
        print(entry.relative_to(tmp_path))  # noqa: T201

    # Failing during teardown makes the printed listing show up in the report.
    pytest.fail("Test failed, see log for details")
def tts_mutagen_mock_fixture_helper() -> Generator[MagicMock]:
    """Patch ``SpeechManager.write_tags`` so the real tag writer never runs.

    The replacement returns its second positional argument unchanged.

    Yields:
        The ``MagicMock`` replacing ``SpeechManager.write_tags``.
    """

    def _return_second_arg(*args: Any) -> Any:
        # NOTE(review): presumably the second argument is the audio data
        # being tagged — confirm against SpeechManager.write_tags.
        return args[1]

    with patch(
        "homeassistant.components.tts.SpeechManager.write_tags",
        side_effect=_return_second_arg,
    ) as write_tags_mock:
        yield write_tags_mock
async def get_media_source_url(hass: HomeAssistant, media_content_id: str) -> str:
    """Resolve ``media_content_id`` through media_source and return its URL.

    The ``media_source`` integration is set up first if it is not loaded yet.
    """
    loaded_components = hass.config.components
    if media_source.DOMAIN not in loaded_components:
        assert await async_setup_component(hass, media_source.DOMAIN, {})

    play_item = await media_source.async_resolve_media(hass, media_content_id, None)
    return play_item.url
async def retrieve_media(
    hass: HomeAssistant, hass_client: ClientSessionGenerator, media_content_id: str
) -> HTTPStatus:
    """Fetch a media source item over HTTP and return the response status.

    Resolves ``media_content_id`` to a URL with ``get_media_source_url``,
    waits with ``hass.async_block_till_done()``, then issues a GET for that
    URL using a fresh client from ``hass_client``.

    Args:
        hass: The Home Assistant instance under test.
        hass_client: Factory producing an HTTP test client.
        media_content_id: Media source content ID to resolve and fetch.

    Returns:
        The status of the GET response.
    """
    url = await get_media_source_url(hass, media_content_id)

    await hass.async_block_till_done()

    # Requesting the URL is what ensures the media has been generated.
    client = await hass_client()
    response = await client.get(url)
    return response.status
class BaseProvider:
    """Test speech provider behaviour shared by MockTTSProvider and MockTTSEntity."""

    def __init__(self, lang: str) -> None:
        """Initialize test provider.

        Args:
            lang: Language reported by ``default_language``.
        """
        self._lang = lang
        # Copy so in-place changes on one provider cannot leak into the
        # module-level SUPPORT_LANGUAGES list (and thus into other tests).
        self._supported_languages = list(SUPPORT_LANGUAGES)
        self._supported_options = ["voice", "age"]

    @property
    def default_language(self) -> str:
        """Return the default language."""
        return self._lang

    @property
    def supported_languages(self) -> list[str]:
        """Return list of supported languages."""
        return self._supported_languages

    @callback
    def async_get_supported_voices(self, language: str) -> list[Voice] | None:
        """Return the voices supported for ``language``.

        Returns:
            Two fixed voices when ``language`` is exactly "en-US", otherwise
            ``None``.
        """
        # NOTE(review): this compares against the hyphenated "en-US", while
        # SUPPORT_LANGUAGES uses underscores ("en_US"); callers presumably
        # rely on the hyphenated form — confirm before changing.
        if language == "en-US":
            return [
                Voice("james_earl_jones", "James Earl Jones"),
                Voice("fran_drescher", "Fran Drescher"),
            ]
        return None

    @property
    def supported_options(self) -> list[str]:
        """Return list of supported options like voice, emotions."""
        return self._supported_options

    def get_tts_audio(
        self, message: str, language: str, options: dict[str, Any]
    ) -> TtsAudioType:
        """Load TTS data.

        Returns:
            An ``("mp3", b"")`` tuple: the extension and an empty payload,
            regardless of ``message``, ``language`` or ``options``.
        """
        return ("mp3", b"")
class MockTTSProvider(BaseProvider, Provider):
    """Test speech provider served by ``MockTTS.async_get_engine``."""

    def __init__(self, lang: str) -> None:
        """Create the provider for ``lang`` with the fixed name "Test"."""
        # BaseProvider.__init__ only sets the language/option attributes, so
        # assigning the name first is order-independent.
        self.name = "Test"
        super().__init__(lang)
class MockTTSEntity(BaseProvider, TextToSpeechEntity):
    """Test speech provider exposed as a TTS entity (see mock_config_entry_setup)."""

    # Fixed entity name for the mock entity.
    _attr_name = "Test"
class MockTTS(MockPlatform):
    """Mock TTS platform that hands out a pre-built provider."""

    # Extends the tts platform schema with an optional CONF_LANG value that
    # must be one of SUPPORT_LANGUAGES and defaults to DEFAULT_LANG.
    PLATFORM_SCHEMA = TTS_PLATFORM_SCHEMA.extend(
        {
            vol.Optional(CONF_LANG, default=DEFAULT_LANG): vol.In(
                SUPPORT_LANGUAGES
            ),
        }
    )

    def __init__(self, provider: MockTTSProvider, **kwargs: Any) -> None:
        """Keep ``provider``; other keyword arguments go to ``MockPlatform``."""
        super().__init__(**kwargs)
        self._provider = provider

    async def async_get_engine(
        self,
        hass: HomeAssistant,
        config: ConfigType,
        discovery_info: DiscoveryInfoType | None = None,
    ) -> Provider | None:
        """Return the stored provider; the arguments are not inspected."""
        provider = self._provider
        return provider
async def mock_setup(
    hass: HomeAssistant,
    mock_provider: MockTTSProvider,
) -> None:
    """Set up tts through the platform path, backed by ``mock_provider``.

    Registers a mock integration for TEST_DOMAIN plus a ``MockTTS`` platform
    wrapping ``mock_provider``, then sets up tts with that platform.
    """
    platform_path = f"{TEST_DOMAIN}.{TTS_DOMAIN}"
    tts_config = {TTS_DOMAIN: {"platform": TEST_DOMAIN}}

    mock_integration(hass, MockModule(domain=TEST_DOMAIN))
    mock_platform(hass, platform_path, MockTTS(mock_provider))

    # NOTE(review): the result of async_setup_component is not checked, so a
    # failed setup is not reported by this helper.
    await async_setup_component(hass, TTS_DOMAIN, tts_config)
    await hass.async_block_till_done()
async def mock_config_entry_setup(
    hass: HomeAssistant,
    tts_entity: MockTTSEntity,
    test_domain: str = TEST_DOMAIN,
) -> MockConfigEntry:
    """Set up ``tts_entity`` through a mock config entry for ``test_domain``.

    Registers a mock integration whose entry setup/unload forward to the tts
    platform, registers a tts platform that adds ``tts_entity``, then creates
    a ``MockConfigEntry``, adds it to ``hass`` and asserts its setup succeeds.

    Returns:
        The config entry that was set up.
    """

    async def _forward_setup(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
        """Set up test config entry by forwarding it to the tts platform."""
        await hass.config_entries.async_forward_entry_setups(
            config_entry, [TTS_DOMAIN]
        )
        return True

    async def _forward_unload(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
        """Unload test config entry from the tts platform."""
        # The forwarded unload result is ignored; unloading always reports True.
        await hass.config_entries.async_forward_entry_unload(config_entry, TTS_DOMAIN)
        return True

    async def _add_tts_entity(
        hass: HomeAssistant,
        config_entry: ConfigEntry,
        async_add_entities: AddEntitiesCallback,
    ) -> None:
        """Set up test tts platform via config entry by adding ``tts_entity``."""
        async_add_entities([tts_entity])

    integration = MockModule(
        test_domain,
        async_setup_entry=_forward_setup,
        async_unload_entry=_forward_unload,
    )
    mock_integration(hass, integration)
    mock_platform(
        hass,
        f"{test_domain}.{TTS_DOMAIN}",
        MockPlatform(async_setup_entry=_add_tts_entity),
    )

    entry = MockConfigEntry(domain=test_domain)
    entry.add_to_hass(hass)
    assert await hass.config_entries.async_setup(entry.entry_id)
    await hass.async_block_till_done()
    return entry