# Mirror of https://github.com/home-assistant/core (Python, 148 lines, 5.0 KiB)
"""Support for the Google Cloud STT service."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import AsyncGenerator, AsyncIterable
|
|
import logging
|
|
|
|
from google.api_core.exceptions import GoogleAPIError, Unauthenticated
|
|
from google.cloud import speech_v1
|
|
|
|
from homeassistant.components.stt import (
|
|
AudioBitRates,
|
|
AudioChannels,
|
|
AudioCodecs,
|
|
AudioFormats,
|
|
AudioSampleRates,
|
|
SpeechMetadata,
|
|
SpeechResult,
|
|
SpeechResultState,
|
|
SpeechToTextEntity,
|
|
)
|
|
from homeassistant.config_entries import ConfigEntry
|
|
from homeassistant.core import HomeAssistant
|
|
from homeassistant.helpers import device_registry as dr
|
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
|
|
|
from .const import (
|
|
CONF_SERVICE_ACCOUNT_INFO,
|
|
CONF_STT_MODEL,
|
|
DEFAULT_STT_MODEL,
|
|
DOMAIN,
|
|
STT_LANGUAGES,
|
|
)
|
|
|
|
# Module-level logger named after this module; used below for debug output of
# streaming responses and for reporting Google API errors.
_LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Register the Google Cloud STT entity for a config entry.

    Builds a ``SpeechAsyncClient`` from the service account info stored in
    the entry data under ``CONF_SERVICE_ACCOUNT_INFO`` and hands a single
    ``GoogleCloudSpeechToTextEntity`` to ``async_add_entities``.
    """
    stt_client = speech_v1.SpeechAsyncClient.from_service_account_info(
        config_entry.data[CONF_SERVICE_ACCOUNT_INFO]
    )
    entity = GoogleCloudSpeechToTextEntity(config_entry, stt_client)
    async_add_entities([entity])
|
|
|
|
|
|
class GoogleCloudSpeechToTextEntity(SpeechToTextEntity):
    """Speech-to-text entity backed by the Google Cloud Speech client."""

    def __init__(
        self,
        entry: ConfigEntry,
        client: speech_v1.SpeechAsyncClient,
    ) -> None:
        """Store the config entry and client and fill in entity attributes.

        The recognition model is read from the entry options under
        ``CONF_STT_MODEL`` and falls back to ``DEFAULT_STT_MODEL``.
        """
        self._entry = entry
        self._client = client
        self._model = entry.options.get(CONF_STT_MODEL, DEFAULT_STT_MODEL)

        self._attr_name = entry.title
        self._attr_unique_id = f"{entry.entry_id}"
        # The device is keyed on the config entry id and flagged as a service.
        self._attr_device_info = dr.DeviceInfo(
            identifiers={(DOMAIN, entry.entry_id)},
            manufacturer="Google",
            model="Cloud",
            entry_type=dr.DeviceEntryType.SERVICE,
        )

    @property
    def supported_languages(self) -> list[str]:
        """Return the supported languages (``STT_LANGUAGES``)."""
        return STT_LANGUAGES

    @property
    def supported_formats(self) -> list[AudioFormats]:
        """Return the supported container formats: WAV and OGG."""
        return [AudioFormats.WAV, AudioFormats.OGG]

    @property
    def supported_codecs(self) -> list[AudioCodecs]:
        """Return the supported codecs: PCM and OPUS."""
        return [AudioCodecs.PCM, AudioCodecs.OPUS]

    @property
    def supported_bit_rates(self) -> list[AudioBitRates]:
        """Return the supported bit rates (only ``BITRATE_16``)."""
        return [AudioBitRates.BITRATE_16]

    @property
    def supported_sample_rates(self) -> list[AudioSampleRates]:
        """Return the supported sample rates (only ``SAMPLERATE_16000``)."""
        return [AudioSampleRates.SAMPLERATE_16000]

    @property
    def supported_channels(self) -> list[AudioChannels]:
        """Return the supported channel layouts (only ``CHANNEL_MONO``)."""
        return [AudioChannels.CHANNEL_MONO]

    async def async_process_audio_stream(
        self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> SpeechResult:
        """Send an audio stream to the STT service and return the transcript.

        The first alternative of the first result of every streaming response
        is appended to the transcript; responses without results or without
        alternatives are skipped.

        Returns a ``SpeechResult`` in state ``SUCCESS`` carrying the
        accumulated transcript, or ``SpeechResult(None, ERROR)`` when a
        ``GoogleAPIError`` is raised. An ``Unauthenticated`` error
        additionally starts a reauth flow on the config entry.
        """
        # OPUS audio is sent as OGG_OPUS; any other codec is sent as LINEAR16.
        if metadata.codec == AudioCodecs.OPUS:
            encoding = speech_v1.RecognitionConfig.AudioEncoding.OGG_OPUS
        else:
            encoding = speech_v1.RecognitionConfig.AudioEncoding.LINEAR16

        recognition_config = speech_v1.RecognitionConfig(
            encoding=encoding,
            sample_rate_hertz=metadata.sample_rate,
            language_code=metadata.language,
            model=self._model,
        )
        streaming_config = speech_v1.StreamingRecognitionConfig(
            config=recognition_config
        )

        async def request_generator() -> (
            AsyncGenerator[speech_v1.StreamingRecognizeRequest]
        ):
            # The first request must only contain a streaming_config
            yield speech_v1.StreamingRecognizeRequest(streaming_config=streaming_config)
            # All subsequent requests must only contain audio_content
            async for chunk in stream:
                yield speech_v1.StreamingRecognizeRequest(audio_content=chunk)

        fragments: list[str] = []
        try:
            responses = await self._client.streaming_recognize(
                requests=request_generator(),
                timeout=10,
            )

            async for response in responses:
                _LOGGER.debug("response: %s", response)
                # Only the first result's first alternative is used.
                if response.results and response.results[0].alternatives:
                    fragments.append(response.results[0].alternatives[0].transcript)
        except GoogleAPIError as err:
            _LOGGER.error("Error occurred during Google Cloud STT call: %s", err)
            if isinstance(err, Unauthenticated):
                self._entry.async_start_reauth(self.hass)
            return SpeechResult(None, SpeechResultState.ERROR)

        return SpeechResult("".join(fragments), SpeechResultState.SUCCESS)
|