core/homeassistant/components/cloud/stt.py

127 lines
4.1 KiB
Python

"""Support for the cloud for speech to text service."""
from __future__ import annotations
from collections.abc import AsyncIterable
import logging
from hass_nabucasa import Cloud
from hass_nabucasa.voice import STT_LANGUAGES, VoiceError
from homeassistant.components.stt import (
AudioBitRates,
AudioChannels,
AudioCodecs,
AudioFormats,
AudioSampleRates,
SpeechMetadata,
SpeechResult,
SpeechResultState,
SpeechToTextEntity,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.setup import async_when_setup
from .assist_pipeline import async_migrate_cloud_pipeline_engine
from .client import CloudClient
from .const import DATA_CLOUD, DATA_PLATFORMS_SETUP, STT_ENTITY_UNIQUE_ID
# Module-level logger, named after this module via ``__name__``.
_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Register the Home Assistant Cloud speech-to-text entity.

    Calls ``set()`` on the STT entry kept under ``DATA_PLATFORMS_SETUP`` in
    ``hass.data``, then adds a single ``CloudProviderEntity`` built from the
    cloud object kept under ``DATA_CLOUD``.
    """
    # Signal that the STT platform has been loaded before creating the entity.
    hass.data[DATA_PLATFORMS_SETUP][Platform.STT].set()

    entity = CloudProviderEntity(hass.data[DATA_CLOUD])
    async_add_entities([entity])
class CloudProviderEntity(SpeechToTextEntity):
    """Speech-to-text entity backed by the Home Assistant Cloud voice API."""

    _attr_name = "Home Assistant Cloud"
    _attr_unique_id = STT_ENTITY_UNIQUE_ID

    def __init__(self, cloud: Cloud[CloudClient]) -> None:
        """Keep a reference to the cloud object used for STT requests."""
        self.cloud = cloud

    @property
    def supported_languages(self) -> list[str]:
        """Languages accepted by the cloud STT service."""
        return STT_LANGUAGES

    @property
    def supported_formats(self) -> list[AudioFormats]:
        """Audio container formats this entity accepts."""
        return [AudioFormats.WAV, AudioFormats.OGG]

    @property
    def supported_codecs(self) -> list[AudioCodecs]:
        """Audio codecs this entity accepts."""
        return [AudioCodecs.PCM, AudioCodecs.OPUS]

    @property
    def supported_bit_rates(self) -> list[AudioBitRates]:
        """Audio bit rates this entity accepts."""
        return [AudioBitRates.BITRATE_16]

    @property
    def supported_sample_rates(self) -> list[AudioSampleRates]:
        """Audio sample rates this entity accepts."""
        return [AudioSampleRates.SAMPLERATE_16000]

    @property
    def supported_channels(self) -> list[AudioChannels]:
        """Audio channel layouts this entity accepts."""
        return [AudioChannels.CHANNEL_MONO]

    async def async_added_to_hass(self) -> None:
        """Register a callback that runs once assist_pipeline is set up."""

        async def _migrate_when_ready(hass: HomeAssistant, _component: str) -> None:
            """Start the cloud pipeline engine migration as a config entry task."""
            entry = self.platform.config_entry
            assert entry
            entry.async_create_task(
                hass,
                async_migrate_cloud_pipeline_engine(
                    self.hass, platform=Platform.STT, engine_id=self.entity_id
                ),
            )

        async_when_setup(self.hass, "assist_pipeline", _migrate_when_ready)

    async def async_process_audio_stream(
        self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> SpeechResult:
        """Send an audio stream to the cloud STT service and wrap the outcome.

        A ``VoiceError`` raised by the cloud call is logged and reported as an
        error result carrying no text.
        """
        # Content type sent with the request; the sample rate is fixed at 16000.
        content_type = "; ".join(
            (
                f"audio/{metadata.format!s}",
                f"codecs=audio/{metadata.codec!s}",
                "samplerate=16000",
            )
        )

        try:
            stt_response = await self.cloud.voice.process_stt(
                stream=stream,
                content_type=content_type,
                language=metadata.language,
            )
        except VoiceError as exc:
            _LOGGER.error("Voice error: %s", exc)
            return SpeechResult(None, SpeechResultState.ERROR)

        # Map the service's success flag onto the STT result state.
        text = stt_response.text
        if stt_response.success:
            state = SpeechResultState.SUCCESS
        else:
            state = SpeechResultState.ERROR
        return SpeechResult(text, state)