mirror of https://github.com/home-assistant/core
847 lines
26 KiB
Python
847 lines
26 KiB
Python
"""Test VoIP protocol."""
|
|
|
|
import asyncio
|
|
import io
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, Mock, patch
|
|
import wave
|
|
|
|
import pytest
|
|
from syrupy.assertion import SnapshotAssertion
|
|
from voip_utils import CallInfo
|
|
|
|
from homeassistant.components import assist_pipeline, assist_satellite, tts, voip
|
|
from homeassistant.components.assist_satellite import AssistSatelliteEntity
|
|
|
|
# pylint: disable-next=hass-component-root-import
|
|
from homeassistant.components.assist_satellite.entity import AssistSatelliteState
|
|
from homeassistant.components.voip import HassVoipDatagramProtocol
|
|
from homeassistant.components.voip.assist_satellite import Tones, VoipAssistSatellite
|
|
from homeassistant.components.voip.devices import VoIPDevice, VoIPDevices
|
|
from homeassistant.components.voip.voip import PreRecordMessageProtocol, make_protocol
|
|
from homeassistant.const import STATE_OFF, STATE_ON, Platform
|
|
from homeassistant.core import Context, HomeAssistant
|
|
from homeassistant.helpers import entity_registry as er
|
|
from homeassistant.helpers.entity_component import EntityComponent
|
|
from homeassistant.setup import async_setup_component
|
|
|
|
# Number of bytes in one second of audio: 16000 samples per second (16Khz)
# at 2 bytes (16-bit) per sample.
_ONE_SECOND = 16000 * 2
# Media id sent in the fake TTS_END events and checked by the media-source mock.
_MEDIA_ID = "12345"
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> None:
    """Mock the TTS cache dir with empty dir.

    The fixture body is intentionally empty: requesting ``mock_tts_cache_dir``
    from an autouse fixture is what activates it for the tests in this module.
    """
|
|
|
|
|
|
def _empty_wav() -> bytes:
    """Return bytes of an empty WAV file.

    The file declares a frame rate of 16000, a sample width of 2 bytes and
    1 channel, and contains no audio frames.
    """
    with io.BytesIO() as wav_io:
        # Leaving the inner ``with`` closes the writer, which is what emits
        # the WAV header into ``wav_io`` (no frames are ever written).
        with wave.open(wav_io, "wb") as wav_file:
            wav_file.setframerate(16000)
            wav_file.setsampwidth(2)
            wav_file.setnchannels(1)

        # Must be read before the outer ``with`` closes the buffer.
        return wav_io.getvalue()
|
|
|
|
|
|
def async_get_satellite_entity(
    hass: HomeAssistant, domain: str, unique_id_prefix: str
) -> AssistSatelliteEntity | None:
    """Get Assist satellite entity.

    Looks up the entity id registered on the assist_satellite platform for
    ``domain`` with unique id ``f"{unique_id_prefix}-assist_satellite"`` and
    returns the entity object the assist_satellite component holds for it.

    Returns None when the entity registry has no such entity id.
    """
    ent_reg = er.async_get(hass)
    satellite_entity_id = ent_reg.async_get_entity_id(
        Platform.ASSIST_SATELLITE, domain, f"{unique_id_prefix}-assist_satellite"
    )
    if satellite_entity_id is None:
        return None
    # Sanity check on the registered entity id itself.
    assert not satellite_entity_id.endswith("none")

    component: EntityComponent[AssistSatelliteEntity] = hass.data[
        assist_satellite.DOMAIN
    ]
    return component.get_entity(satellite_entity_id)
|
|
|
|
|
|
async def test_is_valid_call(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
    call_info: CallInfo,
) -> None:
    """Test that a call is not valid until the allow-call switch is on."""
    assert await async_setup_component(hass, "voip", {})
    protocol = HassVoipDatagramProtocol(hass, voip_devices)

    # Rejected first: the allow_call switch checked below is still off.
    assert not protocol.is_valid_call(call_info)

    ent_reg = er.async_get(hass)
    allowed_call_entity_id = ent_reg.async_get_entity_id(
        "switch", voip.DOMAIN, f"{voip_device.voip_id}-allow_call"
    )
    assert allowed_call_entity_id is not None
    state = hass.states.get(allowed_call_entity_id)
    assert state is not None
    assert state.state == STATE_OFF

    # Allow calls
    hass.states.async_set(allowed_call_entity_id, STATE_ON)
    assert protocol.is_valid_call(call_info)
|
|
|
|
|
|
async def test_calls_not_allowed(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
    call_info: CallInfo,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that a pre-recorded message is played when calls aren't allowed.

    ``send_audio`` is replaced by a closure that captures the bytes handed to
    it; those bytes must be non-silent and match the stored snapshot.
    """
    assert await async_setup_component(hass, "voip", {})
    protocol: PreRecordMessageProtocol = make_protocol(hass, voip_devices, call_info)
    assert isinstance(protocol, PreRecordMessageProtocol)
    assert protocol.file_name == "problem.pcm"

    # Test the playback
    done = asyncio.Event()
    played_audio_bytes = b""

    def send_audio(audio_bytes: bytes, **kwargs):
        nonlocal played_audio_bytes

        # Should be problem.pcm from components/voip
        played_audio_bytes = audio_bytes
        done.set()

    protocol.transport = Mock()
    protocol.loop_delay = 0
    with patch.object(protocol, "send_audio", send_audio):
        # Feed one chunk of silence, then wait for the captured playback.
        protocol.on_chunk(bytes(_ONE_SECOND))

        async with asyncio.timeout(1):
            await done.wait()

    assert sum(played_audio_bytes) > 0
    assert played_audio_bytes == snapshot()
|
|
|
|
|
|
async def test_pipeline_not_found(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
    call_info: CallInfo,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that a pre-recorded message is played when a pipeline isn't found."""
    assert await async_setup_component(hass, "voip", {})

    # Make the pipeline lookup used by the voip module return nothing.
    with patch(
        "homeassistant.components.voip.voip.async_get_pipeline", return_value=None
    ):
        protocol: PreRecordMessageProtocol = make_protocol(
            hass, voip_devices, call_info
        )

    assert isinstance(protocol, PreRecordMessageProtocol)
    assert protocol.file_name == "problem.pcm"
|
|
|
|
|
|
async def test_satellite_prepared(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
    call_info: CallInfo,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that satellite is prepared for a call.

    With the pipeline lookup patched to return a pipeline, ``make_protocol``
    must hand back the device's satellite entity.
    """
    assert await async_setup_component(hass, "voip", {})

    pipeline = assist_pipeline.Pipeline(
        conversation_engine="test",
        conversation_language="en",
        language="en",
        name="test",
        stt_engine="test",
        stt_language="en",
        tts_engine="test",
        tts_language="en",
        tts_voice=None,
        wake_word_entity=None,
        wake_word_id=None,
    )

    satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
    assert isinstance(satellite, VoipAssistSatellite)

    with patch(
        "homeassistant.components.voip.voip.async_get_pipeline",
        return_value=pipeline,
    ):
        protocol = make_protocol(hass, voip_devices, call_info)
        assert protocol == satellite
|
|
|
|
|
|
async def test_pipeline(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
    call_info: CallInfo,
) -> None:
    """Test that pipeline function is called from RTP protocol.

    A fake ``async_pipeline_from_audio_stream`` consumes the audio stream,
    emits pipeline events, and asserts the satellite state after the
    STT_START, INTENT_START and TTS_START events. The test completes once the
    wrapped ``tts_response_finished`` sets ``done``.
    """
    assert await async_setup_component(hass, "voip", {})

    satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
    assert isinstance(satellite, VoipAssistSatellite)
    voip_user_id = satellite.config_entry.data["user"]
    assert voip_user_id

    # Satellite is muted until a call begins
    assert satellite.state == AssistSatelliteState.IDLE

    done = asyncio.Event()

    # Used to test that audio queue is cleared before pipeline starts
    bad_chunk = bytes([1, 2, 3, 4])

    async def async_pipeline_from_audio_stream(
        hass: HomeAssistant,
        context: Context,
        *args,
        device_id: str | None,
        tts_audio_output: str | dict[str, Any] | None,
        **kwargs,
    ):
        assert context.user_id == voip_user_id
        assert device_id == voip_device.device_id

        # voip can only stream WAV
        assert tts_audio_output == {
            tts.ATTR_PREFERRED_FORMAT: "wav",
            tts.ATTR_PREFERRED_SAMPLE_RATE: 16000,
            tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
            tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
        }

        stt_stream = kwargs["stt_stream"]
        event_callback = kwargs["event_callback"]
        in_command = False
        async for chunk in stt_stream:
            # Stream will end when VAD detects end of "speech"
            assert chunk != bad_chunk
            if sum(chunk) > 0:
                in_command = True
            elif in_command:
                break  # done with command

        # Test empty data
        event_callback(
            assist_pipeline.PipelineEvent(
                type="not-used",
                data={},
            )
        )

        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.STT_START,
                data={"engine": "test", "metadata": {}},
            )
        )

        assert satellite.state == AssistSatelliteState.LISTENING

        # Fake STT result
        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.STT_END,
                data={"stt_output": {"text": "fake-text"}},
            )
        )

        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.INTENT_START,
                data={
                    "engine": "test",
                    "language": hass.config.language,
                    "intent_input": "fake-text",
                    "conversation_id": None,
                    "device_id": None,
                },
            )
        )

        assert satellite.state == AssistSatelliteState.PROCESSING

        # Fake intent result
        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.INTENT_END,
                data={
                    "intent_output": {
                        "conversation_id": "fake-conversation",
                    }
                },
            )
        )

        # Fake tts result
        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.TTS_START,
                data={
                    "engine": "test",
                    "language": hass.config.language,
                    "voice": "test",
                    "tts_input": "fake-text",
                },
            )
        )

        assert satellite.state == AssistSatelliteState.RESPONDING

        # Proceed with media output
        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.TTS_END,
                data={"tts_output": {"media_id": _MEDIA_ID}},
            )
        )

        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.RUN_END
            )
        )

    original_tts_response_finished = satellite.tts_response_finished

    def tts_response_finished():
        # Run the real handler first, then release the waiting test.
        original_tts_response_finished()
        done.set()

    async def async_get_media_source_audio(
        hass: HomeAssistant,
        media_source_id: str,
    ) -> tuple[str, bytes]:
        assert media_source_id == _MEDIA_ID
        return ("wav", _empty_wav())

    with (
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            new=async_pipeline_from_audio_stream,
        ),
        patch(
            "homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
            new=async_get_media_source_audio,
        ),
        patch.object(satellite, "tts_response_finished", tts_response_finished),
    ):
        satellite._tones = Tones(0)
        satellite.transport = Mock()

        satellite.connection_made(satellite.transport)
        assert satellite.state == AssistSatelliteState.IDLE

        # Ensure audio queue is cleared before pipeline starts
        satellite._audio_queue.put_nowait(bad_chunk)

        def send_audio(*args, **kwargs):
            # Don't send audio
            pass

        satellite.send_audio = Mock(side_effect=send_audio)

        # silence
        satellite.on_chunk(bytes(_ONE_SECOND))

        # "speech"
        satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))

        # silence
        satellite.on_chunk(bytes(_ONE_SECOND))

        # Wait for mock pipeline to exhaust the audio stream
        async with asyncio.timeout(1):
            await done.wait()

        # Finished speaking
        assert satellite.state == AssistSatelliteState.IDLE
|
|
|
|
|
|
async def test_stt_stream_timeout(
    hass: HomeAssistant, voip_devices: VoIPDevices, voip_device: VoIPDevice
) -> None:
    """Test timeout in STT stream during pipeline run.

    The fake pipeline only drains the STT stream. With the satellite's
    ``_audio_chunk_timeout`` lowered to 0.001 and a single chunk sent, the
    test passes when the transport's ``close`` gets called.
    """
    assert await async_setup_component(hass, "voip", {})

    satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
    assert isinstance(satellite, VoipAssistSatellite)

    done = asyncio.Event()

    async def async_pipeline_from_audio_stream(*args, **kwargs):
        stt_stream = kwargs["stt_stream"]
        async for _chunk in stt_stream:
            # Iterate over stream
            pass

    with patch(
        "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
        new=async_pipeline_from_audio_stream,
    ):
        satellite._tones = Tones(0)
        satellite._audio_chunk_timeout = 0.001
        # spec restricts the mock to a ``close`` attribute only.
        transport = Mock(spec=["close"])
        satellite.connection_made(transport)

        # Closing the transport will cause the test to succeed
        transport.close.side_effect = done.set

        # silence
        satellite.on_chunk(bytes(_ONE_SECOND))

        # Wait for mock pipeline to time out
        async with asyncio.timeout(1):
            await done.wait()
|
|
|
|
|
|
async def test_tts_timeout(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
) -> None:
    """Test that TTS will time out based on its length.

    The media-source mock returns an empty WAV while the audio sender sleeps
    for 2 seconds on anything that is not a tone, so the wrapped ``_send_tts``
    is expected to raise ``TimeoutError``.
    """
    assert await async_setup_component(hass, "voip", {})

    satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
    assert isinstance(satellite, VoipAssistSatellite)

    done = asyncio.Event()

    async def async_pipeline_from_audio_stream(*args, **kwargs):
        stt_stream = kwargs["stt_stream"]
        event_callback = kwargs["event_callback"]
        in_command = False
        async for chunk in stt_stream:
            if sum(chunk) > 0:
                in_command = True
            elif in_command:
                break  # done with command

        # Fake STT result
        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.STT_END,
                data={"stt_output": {"text": "fake-text"}},
            )
        )

        # Fake intent result
        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.INTENT_END,
                data={
                    "intent_output": {
                        "conversation_id": "fake-conversation",
                    }
                },
            )
        )

        # Proceed with media output
        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.TTS_END,
                data={"tts_output": {"media_id": _MEDIA_ID}},
            )
        )

    # Marker payload installed for every tone so the sender can recognise it.
    tone_bytes = bytes([1, 2, 3, 4])

    async def async_send_audio(audio_bytes: bytes, **kwargs):
        if audio_bytes == tone_bytes:
            # Not TTS
            return

        # Block here to force a timeout in _send_tts
        await asyncio.sleep(2)

    async def async_get_media_source_audio(
        hass: HomeAssistant,
        media_source_id: str,
    ) -> tuple[str, bytes]:
        # Should time out immediately
        return ("wav", _empty_wav())

    with (
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            new=async_pipeline_from_audio_stream,
        ),
        patch(
            "homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
            new=async_get_media_source_audio,
        ),
    ):
        satellite._tts_extra_timeout = 0.001
        for tone in Tones:
            satellite._tone_bytes[tone] = tone_bytes

        satellite.transport = Mock()
        satellite.send_audio = Mock()

        original_send_tts = satellite._send_tts

        async def send_tts(*args, **kwargs):
            # Call original then end test successfully
            with pytest.raises(TimeoutError):
                await original_send_tts(*args, **kwargs)

            done.set()

        satellite._async_send_audio = AsyncMock(side_effect=async_send_audio)  # type: ignore[method-assign]
        satellite._send_tts = AsyncMock(side_effect=send_tts)  # type: ignore[method-assign]

        # silence
        satellite.on_chunk(bytes(_ONE_SECOND))

        # "speech"
        satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))

        # silence
        satellite.on_chunk(bytes(_ONE_SECOND))

        # Wait for mock pipeline to exhaust the audio stream
        async with asyncio.timeout(1):
            await done.wait()
|
|
|
|
|
|
async def test_tts_wrong_extension(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
) -> None:
    """Test that TTS will only stream WAV audio.

    The media-source mock reports the extension "mp3", so the wrapped
    ``_send_tts`` is expected to raise ``ValueError``.
    """
    assert await async_setup_component(hass, "voip", {})

    satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
    assert isinstance(satellite, VoipAssistSatellite)

    done = asyncio.Event()

    async def async_pipeline_from_audio_stream(*args, **kwargs):
        stt_stream = kwargs["stt_stream"]
        event_callback = kwargs["event_callback"]
        in_command = False
        async for chunk in stt_stream:
            if sum(chunk) > 0:
                in_command = True
            elif in_command:
                break  # done with command

        # Fake STT result
        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.STT_END,
                data={"stt_output": {"text": "fake-text"}},
            )
        )

        # Fake intent result
        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.INTENT_END,
                data={
                    "intent_output": {
                        "conversation_id": "fake-conversation",
                    }
                },
            )
        )

        # Proceed with media output
        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.TTS_END,
                data={"tts_output": {"media_id": _MEDIA_ID}},
            )
        )

    async def async_get_media_source_audio(
        hass: HomeAssistant,
        media_source_id: str,
    ) -> tuple[str, bytes]:
        # Should fail because it's not "wav"
        return ("mp3", b"")

    with (
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            new=async_pipeline_from_audio_stream,
        ),
        patch(
            "homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
            new=async_get_media_source_audio,
        ),
    ):
        satellite.transport = Mock()

        original_send_tts = satellite._send_tts

        async def send_tts(*args, **kwargs):
            # Call original then end test successfully
            with pytest.raises(ValueError):
                await original_send_tts(*args, **kwargs)

            done.set()

        satellite._send_tts = AsyncMock(side_effect=send_tts)  # type: ignore[method-assign]

        # silence
        satellite.on_chunk(bytes(_ONE_SECOND))

        # "speech"
        satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))

        # silence (assumes relaxed VAD sensitivity)
        satellite.on_chunk(bytes(_ONE_SECOND * 4))

        # Wait for mock pipeline to exhaust the audio stream
        async with asyncio.timeout(1):
            await done.wait()
|
|
|
|
|
|
async def test_tts_wrong_wav_format(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
) -> None:
    """Test that TTS will only stream WAV audio with a specific format.

    The media-source mock returns a WAV declared as 22050 Hz with 2 channels,
    so the wrapped ``_send_tts`` is expected to raise ``ValueError``.
    """
    assert await async_setup_component(hass, "voip", {})

    satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
    assert isinstance(satellite, VoipAssistSatellite)

    done = asyncio.Event()

    async def async_pipeline_from_audio_stream(*args, **kwargs):
        stt_stream = kwargs["stt_stream"]
        event_callback = kwargs["event_callback"]
        in_command = False
        async for chunk in stt_stream:
            if sum(chunk) > 0:
                in_command = True
            elif in_command:
                break  # done with command

        # Fake STT result
        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.STT_END,
                data={"stt_output": {"text": "fake-text"}},
            )
        )

        # Fake intent result
        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.INTENT_END,
                data={
                    "intent_output": {
                        "conversation_id": "fake-conversation",
                    }
                },
            )
        )

        # Proceed with media output
        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.TTS_END,
                data={"tts_output": {"media_id": _MEDIA_ID}},
            )
        )

    async def async_get_media_source_audio(
        hass: HomeAssistant,
        media_source_id: str,
    ) -> tuple[str, bytes]:
        # Should fail because it's not 16Khz, 16-bit mono
        with io.BytesIO() as wav_io:
            # Closing the writer emits the header; no frames are written.
            with wave.open(wav_io, "wb") as wav_file:
                wav_file.setframerate(22050)
                wav_file.setsampwidth(2)
                wav_file.setnchannels(2)

            return ("wav", wav_io.getvalue())

    with (
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            new=async_pipeline_from_audio_stream,
        ),
        patch(
            "homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
            new=async_get_media_source_audio,
        ),
    ):
        satellite.transport = Mock()

        original_send_tts = satellite._send_tts

        async def send_tts(*args, **kwargs):
            # Call original then end test successfully
            with pytest.raises(ValueError):
                await original_send_tts(*args, **kwargs)

            done.set()

        satellite._send_tts = AsyncMock(side_effect=send_tts)  # type: ignore[method-assign]

        # silence
        satellite.on_chunk(bytes(_ONE_SECOND))

        # "speech"
        satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))

        # silence (assumes relaxed VAD sensitivity)
        satellite.on_chunk(bytes(_ONE_SECOND * 4))

        # Wait for mock pipeline to exhaust the audio stream
        async with asyncio.timeout(1):
            await done.wait()
|
|
|
|
|
|
async def test_empty_tts_output(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
) -> None:
    """Test that TTS will not stream when output is empty.

    The fake pipeline ends with a TTS_END event whose ``tts_output`` is an
    empty dict; ``_send_tts`` is patched and must never be called.
    """
    assert await async_setup_component(hass, "voip", {})

    satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
    assert isinstance(satellite, VoipAssistSatellite)

    async def async_pipeline_from_audio_stream(*args, **kwargs):
        stt_stream = kwargs["stt_stream"]
        event_callback = kwargs["event_callback"]
        in_command = False
        async for chunk in stt_stream:
            if sum(chunk) > 0:
                in_command = True
            elif in_command:
                break  # done with command

        # Fake STT result
        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.STT_END,
                data={"stt_output": {"text": "fake-text"}},
            )
        )

        # Fake intent result
        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.INTENT_END,
                data={
                    "intent_output": {
                        "conversation_id": "fake-conversation",
                    }
                },
            )
        )

        # Empty TTS output
        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.TTS_END,
                data={"tts_output": {}},
            )
        )

    with (
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            new=async_pipeline_from_audio_stream,
        ),
        patch(
            "homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
        ) as mock_send_tts,
    ):
        satellite.transport = Mock()

        # silence
        satellite.on_chunk(bytes(_ONE_SECOND))

        # "speech"
        satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))

        # silence (assumes relaxed VAD sensitivity)
        satellite.on_chunk(bytes(_ONE_SECOND * 4))

        # Wait for mock pipeline to finish
        async with asyncio.timeout(1):
            await satellite._tts_done.wait()

        mock_send_tts.assert_not_called()
|
|
|
|
|
|
async def test_pipeline_error(
    hass: HomeAssistant,
    voip_devices: VoIPDevices,
    voip_device: VoIPDevice,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that a pipeline error causes the error tone to be played.

    The fake pipeline emits a single ERROR event. The bytes handed to the
    mocked ``_async_send_audio`` must be non-silent and match the snapshot.
    """
    assert await async_setup_component(hass, "voip", {})

    satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
    assert isinstance(satellite, VoipAssistSatellite)

    done = asyncio.Event()
    played_audio_bytes = b""

    async def async_pipeline_from_audio_stream(*args, **kwargs):
        # Fake error
        event_callback = kwargs["event_callback"]
        event_callback(
            assist_pipeline.PipelineEvent(
                type=assist_pipeline.PipelineEventType.ERROR,
                data={"code": "error-code", "message": "error message"},
            )
        )

    async def async_send_audio(audio_bytes: bytes, **kwargs):
        nonlocal played_audio_bytes

        # Should be error.pcm from components/voip
        played_audio_bytes = audio_bytes
        done.set()

    with patch(
        "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
        new=async_pipeline_from_audio_stream,
    ):
        # Only the error tone is enabled for this run.
        satellite._tones = Tones.ERROR
        satellite.transport = Mock()
        satellite._async_send_audio = AsyncMock(side_effect=async_send_audio)  # type: ignore[method-assign]

        satellite.on_chunk(bytes(_ONE_SECOND))

        # Wait for error tone to be played
        async with asyncio.timeout(1):
            await done.wait()

        assert sum(played_audio_bytes) > 0
        assert played_audio_bytes == snapshot()
|