core/tests/components/wyoming/test_satellite.py

1286 lines
43 KiB
Python

"""Test Wyoming satellite."""
from __future__ import annotations
import asyncio
from collections.abc import Callable
import io
from typing import Any
from unittest.mock import patch
import wave
from wyoming.asr import Transcribe, Transcript
from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.error import Error
from wyoming.event import Event
from wyoming.info import Info
from wyoming.ping import Ping, Pong
from wyoming.pipeline import PipelineStage, RunPipeline
from wyoming.satellite import RunSatellite
from wyoming.timer import TimerCancelled, TimerFinished, TimerStarted, TimerUpdated
from wyoming.tts import Synthesize
from wyoming.vad import VoiceStarted, VoiceStopped
from wyoming.wake import Detect, Detection
from homeassistant.components import assist_pipeline, wyoming
from homeassistant.components.wyoming.assist_satellite import WyomingAssistSatellite
from homeassistant.components.wyoming.devices import SatelliteDevice
from homeassistant.const import STATE_ON
from homeassistant.core import HomeAssistant, State
from homeassistant.helpers import intent as intent_helper
from homeassistant.setup import async_setup_component
from . import SATELLITE_INFO, WAKE_WORD_INFO, MockAsyncTcpClient
from tests.common import MockConfigEntry
async def setup_config_entry(hass: HomeAssistant) -> MockConfigEntry:
    """Create, register and set up a config entry for a Wyoming satellite.

    This lives apart from the ``satellite_config_entry`` helper in conftest.py
    so that a test can install its patches before setup starts the satellite
    task.
    """
    connection = {
        "host": "1.2.3.4",
        "port": 1234,
    }
    config_entry = MockConfigEntry(
        title="Test Satellite",
        domain="wyoming",
        data=connection,
    )
    config_entry.add_to_hass(hass)

    # Set the entry up and let whatever it scheduled settle before returning.
    await hass.config_entries.async_setup(config_entry.entry_id)
    await hass.async_block_till_done()

    return config_entry
def get_test_wav() -> bytes:
    """Build a tiny in-memory WAV file used as fake text-to-speech output.

    The file is 22050 Hz, 16-bit (2-byte samples), mono, and its data chunk
    holds the raw bytes ``b"123"``.
    """
    with io.BytesIO() as buffer:
        with wave.open(buffer, "wb") as writer:
            writer.setnchannels(1)
            writer.setsampwidth(2)
            writer.setframerate(22050)

            # Raw audio payload; tests compare streamed chunks against it.
            writer.writeframes(b"123")

        # The writer is closed (header finalized) but the buffer is still open.
        wav_bytes = buffer.getvalue()

    return wav_bytes
class SatelliteAsyncTcpClient(MockAsyncTcpClient):
    """Fake satellite connection that records the events written to it.

    Each event type of interest has an ``asyncio.Event`` flag that is set when
    Home Assistant writes an event of that type. Most types also keep the last
    decoded event so tests can inspect it. Once the scripted responses are
    exhausted, ``read_event`` keeps returning a microphone audio chunk.
    """

    # (event class, attribute prefix, store decoded event?)
    # Writing a matching event sets ``<prefix>_event``. When the flag is True,
    # the decoded event is stored in ``<prefix>`` before the flag is set.
    _RECORDED_EVENTS = (
        (RunSatellite, "run_satellite", False),
        (Detect, "detect", False),
        (Detection, "detection", True),
        (Transcribe, "transcribe", True),
        (VoiceStarted, "voice_started", True),
        (VoiceStopped, "voice_stopped", True),
        (Transcript, "transcript", True),
        (Synthesize, "synthesize", True),
        (AudioStart, "tts_audio_start", False),
        (AudioChunk, "tts_audio_chunk", True),
        (AudioStop, "tts_audio_stop", False),
        (Error, "error", True),
        (Pong, "pong", True),
        (Ping, "ping", True),
        (TimerStarted, "timer_started", True),
        (TimerUpdated, "timer_updated", True),
        (TimerCancelled, "timer_cancelled", True),
        (TimerFinished, "timer_finished", True),
    )

    def __init__(self, responses: list[Event]) -> None:
        """Initialize client."""
        super().__init__(responses)

        # Connection / satellite lifecycle
        self.connect_event = asyncio.Event()
        self.run_satellite_event = asyncio.Event()

        # Wake word
        self.detect_event = asyncio.Event()
        self.detection_event = asyncio.Event()
        self.detection: Detection | None = None

        # Speech-to-text
        self.transcribe_event = asyncio.Event()
        self.transcribe: Transcribe | None = None
        self.voice_started_event = asyncio.Event()
        self.voice_started: VoiceStarted | None = None
        self.voice_stopped_event = asyncio.Event()
        self.voice_stopped: VoiceStopped | None = None
        self.transcript_event = asyncio.Event()
        self.transcript: Transcript | None = None

        # Text-to-speech
        self.synthesize_event = asyncio.Event()
        self.synthesize: Synthesize | None = None
        self.tts_audio_start_event = asyncio.Event()
        self.tts_audio_chunk_event = asyncio.Event()
        self.tts_audio_stop_event = asyncio.Event()
        self.tts_audio_chunk: AudioChunk | None = None

        # Errors and keep-alive
        self.error_event = asyncio.Event()
        self.error: Error | None = None
        self.pong_event = asyncio.Event()
        self.pong: Pong | None = None
        self.ping_event = asyncio.Event()
        self.ping: Ping | None = None

        # Timers
        self.timer_started_event = asyncio.Event()
        self.timer_started: TimerStarted | None = None
        self.timer_updated_event = asyncio.Event()
        self.timer_updated: TimerUpdated | None = None
        self.timer_cancelled_event = asyncio.Event()
        self.timer_cancelled: TimerCancelled | None = None
        self.timer_finished_event = asyncio.Event()
        self.timer_finished: TimerFinished | None = None

        # Returned by read_event once the scripted responses run out.
        self._mic_audio_chunk = AudioChunk(
            rate=16000, width=2, channels=1, audio=b"chunk"
        ).event()

    async def connect(self) -> None:
        """Record that a connection was opened."""
        self.connect_event.set()

    async def write_event(self, event: Event):
        """Record an event sent by Home Assistant to the satellite."""
        for event_class, prefix, store in self._RECORDED_EVENTS:
            if not event_class.is_type(event.type):
                continue
            if store:
                setattr(self, prefix, event_class.from_event(event))
            getattr(self, f"{prefix}_event").set()
            break

    async def read_event(self) -> Event | None:
        """Return the next scripted response, or mic audio once none remain."""
        scripted = await super().read_event()
        return scripted if scripted else self._mic_audio_chunk

    def inject_event(self, event: Event) -> None:
        """Make ``event`` the very next response returned by ``read_event``."""
        # Build a new list rather than mutating the caller's response list.
        self.responses = [event] + list(self.responses)
async def test_satellite_pipeline(hass: HomeAssistant) -> None:
    """Test running a pipeline with a satellite.

    The real pipeline is replaced by a fake that captures its event callback.
    The test then fires pipeline events one by one (wake word, STT, VAD, TTS,
    run end) and checks that the fake satellite client receives the matching
    Wyoming event for each. It also exercises ping/pong handling and checks
    that the pipeline is started again after RUN_END (``restart_on_end=True``).
    """
    assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})

    # The satellite requests a WAKE..TTS run that restarts when it ends.
    events = [
        RunPipeline(
            start_stage=PipelineStage.WAKE,
            end_stage=PipelineStage.TTS,
            restart_on_end=True,
        ).event(),
    ]

    # Filled in by the fake pipeline below so the test can drive the run.
    pipeline_kwargs: dict[str, Any] = {}
    pipeline_event_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = (
        None
    )
    run_pipeline_called = asyncio.Event()
    audio_chunk_received = asyncio.Event()

    async def async_pipeline_from_audio_stream(
        hass: HomeAssistant,
        context,
        event_callback,
        stt_metadata,
        stt_stream,
        **kwargs,
    ) -> None:
        """Stand-in pipeline: capture the callback/kwargs and read one chunk."""
        nonlocal pipeline_kwargs, pipeline_event_callback
        pipeline_kwargs = kwargs
        pipeline_event_callback = event_callback

        run_pipeline_called.set()
        # Stop reading as soon as one non-empty audio chunk arrives.
        async for chunk in stt_stream:
            if chunk:
                audio_chunk_received.set()
                break

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            SatelliteAsyncTcpClient(events),
        ) as mock_client,
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            async_pipeline_from_audio_stream,
        ),
        # TTS media is resolved to the small WAV built by get_test_wav().
        patch(
            "homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio",
            return_value=("wav", get_test_wav()),
        ),
        # NOTE(review): presumably removes the delay before pings are sent.
        patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0),
    ):
        entry = await setup_config_entry(hass)
        device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device
        assert device is not None

        async with asyncio.timeout(1):
            await mock_client.connect_event.wait()
            await mock_client.run_satellite_event.wait()

        async with asyncio.timeout(1):
            await run_pipeline_called.wait()

        # Reset so we can check the pipeline is automatically restarted below
        run_pipeline_called.clear()

        assert pipeline_event_callback is not None
        assert pipeline_kwargs.get("device_id") == device.device_id

        # Test a ping
        mock_client.inject_event(Ping("test-ping").event())

        # Pong is expected with the same text
        async with asyncio.timeout(1):
            await mock_client.pong_event.wait()

        assert mock_client.pong is not None
        assert mock_client.pong.text == "test-ping"

        # The client should have received the first ping
        async with asyncio.timeout(1):
            await mock_client.ping_event.wait()

        assert mock_client.ping is not None

        # Reset and send a pong back.
        # We will get a second ping by the end of the test.
        mock_client.ping_event.clear()
        mock_client.ping = None
        mock_client.inject_event(Pong().event())

        # Start detecting wake word
        pipeline_event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.WAKE_WORD_START
            )
        )
        async with asyncio.timeout(1):
            await mock_client.detect_event.wait()

        # Waiting for the wake word alone does not mark the device active.
        assert not device.is_active
        assert not device.is_muted

        # Push in some audio
        mock_client.inject_event(
            AudioChunk(rate=16000, width=2, channels=1, audio=bytes(1024)).event()
        )

        # Wake word is detected
        pipeline_event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.WAKE_WORD_END,
                {"wake_word_output": {"wake_word_id": "test_wake_word"}},
            )
        )
        async with asyncio.timeout(1):
            await mock_client.detection_event.wait()

        assert mock_client.detection is not None
        assert mock_client.detection.name == "test_wake_word"

        # Speech-to-text started
        pipeline_event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.STT_START,
                {"metadata": {"language": "en"}},
            )
        )
        async with asyncio.timeout(1):
            await mock_client.transcribe_event.wait()

        assert mock_client.transcribe is not None
        assert mock_client.transcribe.language == "en"

        # "Assist in progress" sensor should be active now
        assert device.is_active

        # Push in some audio
        mock_client.inject_event(
            AudioChunk(rate=16000, width=2, channels=1, audio=bytes(1024)).event()
        )

        # User started speaking
        pipeline_event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.STT_VAD_START, {"timestamp": 1234}
            )
        )
        async with asyncio.timeout(1):
            await mock_client.voice_started_event.wait()

        assert mock_client.voice_started is not None
        assert mock_client.voice_started.timestamp == 1234

        # User stopped speaking
        pipeline_event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.STT_VAD_END, {"timestamp": 5678}
            )
        )
        async with asyncio.timeout(1):
            await mock_client.voice_stopped_event.wait()

        assert mock_client.voice_stopped is not None
        assert mock_client.voice_stopped.timestamp == 5678

        # Speech-to-text transcription
        pipeline_event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.STT_END,
                {"stt_output": {"text": "test transcript"}},
            )
        )
        async with asyncio.timeout(1):
            await mock_client.transcript_event.wait()

        assert mock_client.transcript is not None
        assert mock_client.transcript.text == "test transcript"

        # Text-to-speech text
        pipeline_event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.TTS_START,
                {
                    "tts_input": "test text to speak",
                    "voice": "test voice",
                },
            )
        )
        async with asyncio.timeout(1):
            await mock_client.synthesize_event.wait()

        assert mock_client.synthesize is not None
        assert mock_client.synthesize.text == "test text to speak"
        assert mock_client.synthesize.voice is not None
        assert mock_client.synthesize.voice.name == "test voice"

        # Text-to-speech media
        pipeline_event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.TTS_END,
                {"tts_output": {"media_id": "test media id"}},
            )
        )
        # The WAV is streamed as AudioStart, AudioChunk(s), AudioStop.
        async with asyncio.timeout(1):
            await mock_client.tts_audio_start_event.wait()
            await mock_client.tts_audio_chunk_event.wait()
            await mock_client.tts_audio_stop_event.wait()

        # Verify audio chunk from test WAV
        assert mock_client.tts_audio_chunk is not None
        assert mock_client.tts_audio_chunk.rate == 22050
        assert mock_client.tts_audio_chunk.width == 2
        assert mock_client.tts_audio_chunk.channels == 1
        assert mock_client.tts_audio_chunk.audio == b"123"

        # Pipeline finished
        pipeline_event_callback(
            assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END)
        )
        assert not device.is_active

        # The client should have received another ping by now
        async with asyncio.timeout(1):
            await mock_client.ping_event.wait()

        assert mock_client.ping is not None

        # Pipeline should automatically restart
        async with asyncio.timeout(1):
            await run_pipeline_called.wait()

        # Stop the satellite
        await hass.config_entries.async_unload(entry.entry_id)
        await hass.async_block_till_done()
async def test_satellite_muted(hass: HomeAssistant) -> None:
    """Test the muted callback for a satellite whose mute switch is restored on."""
    real_on_muted = WyomingAssistSatellite.on_muted
    muted_handled = asyncio.Event()

    async def fake_on_muted(self):
        # Set the muted-changed event, then run the original handler.
        self._muted_changed_event.set()
        await real_on_muted(self)

        # Make sure the satellite stops.
        self.is_running = False

        # Unmute again and let the test continue.
        self.device.set_is_muted(False)
        muted_handled.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            SatelliteAsyncTcpClient([]),
        ),
        patch(
            "homeassistant.components.wyoming.switch.WyomingSatelliteMuteSwitch.async_get_last_state",
            return_value=State("switch.test_mute", STATE_ON),
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_muted",
            fake_on_muted,
        ),
    ):
        config_entry = await setup_config_entry(hass)

        async with asyncio.timeout(1):
            await muted_handled.wait()

        # Shut the satellite down.
        await hass.config_entries.async_unload(config_entry.entry_id)
        await hass.async_block_till_done()
async def test_satellite_restart(hass: HomeAssistant) -> None:
    """Test that the pipeline loop restarts after an unexpected error."""
    real_on_restart = WyomingAssistSatellite.on_restart
    restarted = asyncio.Event()

    async def fake_on_restart(self):
        await real_on_restart(self)

        # One restart is enough; stop the satellite so the loop ends.
        self.stop_satellite()
        restarted.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        # Every connection attempt blows up with an unexpected error.
        patch(
            "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite._connect_and_loop",
            side_effect=RuntimeError(),
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_restart",
            fake_on_restart,
        ),
        patch("homeassistant.components.wyoming.assist_satellite._RESTART_SECONDS", 0),
    ):
        await setup_config_entry(hass)

        async with asyncio.timeout(1):
            await restarted.wait()
async def test_satellite_reconnect(hass: HomeAssistant) -> None:
    """Test that the satellite reconnects after the connection is refused."""
    real_on_reconnect = WyomingAssistSatellite.on_reconnect
    reconnected_twice = asyncio.Event()
    stopped = asyncio.Event()
    reconnect_count = 0

    async def fake_on_reconnect(self):
        nonlocal reconnect_count

        await real_on_reconnect(self)
        reconnect_count += 1

        # After the second reconnect, signal the test and stop retrying.
        if reconnect_count < 2:
            return
        reconnected_twice.set()
        self.stop_satellite()

    async def fake_on_stopped(self):
        stopped.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        # Every connection attempt is refused.
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient.connect",
            side_effect=ConnectionRefusedError(),
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_reconnect",
            fake_on_reconnect,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_stopped",
            fake_on_stopped,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite._RECONNECT_SECONDS", 0
        ),
    ):
        await setup_config_entry(hass)

        async with asyncio.timeout(1):
            await reconnected_twice.wait()
            await stopped.wait()
async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None:
    """Test a satellite that disconnects before requesting a pipeline run."""
    restarted = asyncio.Event()

    async def fake_on_restart(self):
        self.stop_satellite()
        restarted.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        # No scripted events at all, so a RunPipeline request never arrives.
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            MockAsyncTcpClient([]),
        ),
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
        ) as pipeline_mock,
        patch(
            "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_restart",
            fake_on_restart,
        ),
    ):
        await setup_config_entry(hass)

        async with asyncio.timeout(1):
            await restarted.wait()

        # Without a RunPipeline request, no pipeline may have been started.
        pipeline_mock.assert_not_called()
async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None:
    """Test a satellite that disconnects while a pipeline is running."""
    # Only the RunPipeline request is scripted; no audio chunks follow it.
    script = [
        RunPipeline(
            start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
        ).event(),
    ]
    restarted = asyncio.Event()
    stopped = asyncio.Event()

    async def fake_on_restart(self):
        # Pretend the device's active flag got stuck on.
        self.device.is_active = True
        self.stop_satellite()
        restarted.set()

    async def fake_on_stopped(self):
        stopped.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            MockAsyncTcpClient(script),
        ),
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
        ) as pipeline_mock,
        patch(
            "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_restart",
            fake_on_restart,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_stopped",
            fake_on_stopped,
        ),
    ):
        config_entry = await setup_config_entry(hass)
        satellite_device: SatelliteDevice = hass.data[wyoming.DOMAIN][
            config_entry.entry_id
        ].device

        async with asyncio.timeout(1):
            await restarted.wait()
            await stopped.wait()

        # Exactly one pipeline run was started...
        pipeline_mock.assert_called_once()

        # ...and the stuck active flag must have been reset.
        assert not satellite_device.is_active
async def test_satellite_error_during_pipeline(hass: HomeAssistant) -> None:
    """Test that a pipeline ERROR event is forwarded to the satellite."""
    # Only the RunPipeline request is scripted; no audio chunks follow it.
    script = [
        RunPipeline(
            start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
        ).event(),
    ]
    pipeline_started = asyncio.Event()

    def _fake_pipeline_run(*args: Any, **kwargs: Any) -> None:
        pipeline_started.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            SatelliteAsyncTcpClient(script),
        ) as client,
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            wraps=_fake_pipeline_run,
        ) as pipeline_mock,
    ):
        await setup_config_entry(hass)

        async with asyncio.timeout(1):
            await pipeline_started.wait()
            await client.connect_event.wait()
            await client.run_satellite_event.wait()

        pipeline_mock.assert_called_once()
        send_pipeline_event = pipeline_mock.call_args.kwargs["event_callback"]

        pipeline_error = assist_pipeline.PipelineEvent(
            assist_pipeline.PipelineEventType.ERROR,
            {"code": "test code", "message": "test message"},
        )
        send_pipeline_event(pipeline_error)

        async with asyncio.timeout(1):
            await client.error_event.wait()

        received = client.error
        assert received is not None
        assert received.text == "test message"
        assert received.code == "test code"
async def test_tts_not_wav(hass: HomeAssistant) -> None:
    """Test that non-WAV text-to-speech audio makes streaming raise ValueError."""
    assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})

    real_stream_tts = WyomingAssistSatellite._stream_tts
    tts_rejected = asyncio.Event()

    async def fake_stream_tts(self, media_id):
        try:
            await real_stream_tts(self, media_id)
        except ValueError:
            tts_rejected.set()

    # The satellite asks for a TTS-only pipeline run.
    script = [
        RunPipeline(start_stage=PipelineStage.TTS, end_stage=PipelineStage.TTS).event(),
    ]
    pipeline_started = asyncio.Event()

    def _fake_pipeline_run(*args: Any, **kwargs: Any) -> None:
        pipeline_started.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            SatelliteAsyncTcpClient(script),
        ) as client,
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            wraps=_fake_pipeline_run,
        ) as pipeline_mock,
        # TTS media resolves to MP3 bytes instead of WAV.
        patch(
            "homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio",
            return_value=("mp3", bytes(1)),
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite._stream_tts",
            fake_stream_tts,
        ),
    ):
        config_entry = await setup_config_entry(hass)

        async with asyncio.timeout(1):
            await pipeline_started.wait()
            await client.connect_event.wait()
            await client.run_satellite_event.wait()

        pipeline_mock.assert_called_once()
        send_pipeline_event = pipeline_mock.call_args.kwargs["event_callback"]

        # Text to synthesize
        tts_start = assist_pipeline.PipelineEvent(
            assist_pipeline.PipelineEventType.TTS_START,
            {
                "tts_input": "test text to speak",
                "voice": "test voice",
            },
        )
        send_pipeline_event(tts_start)
        async with asyncio.timeout(1):
            await client.synthesize_event.wait()

        # Synthesized media
        tts_end = assist_pipeline.PipelineEvent(
            assist_pipeline.PipelineEventType.TTS_END,
            {"tts_output": {"media_id": "test media id"}},
        )
        send_pipeline_event(tts_end)

        # Only WAV is supported, so streaming the MP3 must fail.
        async with asyncio.timeout(1):
            await tts_rejected.wait()

        # Shut the satellite down.
        await hass.config_entries.async_unload(config_entry.entry_id)
        await hass.async_block_till_done()
async def test_pipeline_changed(hass: HomeAssistant) -> None:
    """Test that selecting a different pipeline stops the running one."""
    assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})

    script = [
        RunPipeline(
            start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
        ).event(),
    ]

    captured_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = None
    pipeline_running = asyncio.Event()
    pipeline_finished = asyncio.Event()

    async def fake_pipeline_run(
        hass: HomeAssistant,
        context,
        event_callback,
        stt_metadata,
        stt_stream,
        **kwargs,
    ) -> None:
        nonlocal captured_callback
        captured_callback = event_callback
        pipeline_running.set()

        # Consume audio until the stream ends.
        async for _chunk in stt_stream:
            pass

        pipeline_finished.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            SatelliteAsyncTcpClient(script),
        ) as client,
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            fake_pipeline_run,
        ),
    ):
        config_entry = await setup_config_entry(hass)
        satellite_device: SatelliteDevice = hass.data[wyoming.DOMAIN][
            config_entry.entry_id
        ].device

        async with asyncio.timeout(1):
            await client.connect_event.wait()
            await client.run_satellite_event.wait()

        # Wait for the pipeline to start running.
        async with asyncio.timeout(1):
            await pipeline_running.wait()

        assert captured_callback is not None

        # Switching pipelines must stop the one that is running.
        satellite_device.set_pipeline_name("different pipeline")

        async with asyncio.timeout(1):
            await pipeline_finished.wait()

        # Shut the satellite down.
        await hass.config_entries.async_unload(config_entry.entry_id)
        await hass.async_block_till_done()
async def test_audio_settings_changed(hass: HomeAssistant) -> None:
    """Test that changing an audio setting stops the running pipeline."""
    assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})

    script = [
        RunPipeline(
            start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
        ).event(),
    ]

    captured_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = None
    pipeline_running = asyncio.Event()
    pipeline_finished = asyncio.Event()

    async def fake_pipeline_run(
        hass: HomeAssistant,
        context,
        event_callback,
        stt_metadata,
        stt_stream,
        **kwargs,
    ) -> None:
        nonlocal captured_callback
        captured_callback = event_callback
        pipeline_running.set()

        # Consume audio until the stream ends.
        async for _chunk in stt_stream:
            pass

        pipeline_finished.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            SatelliteAsyncTcpClient(script),
        ) as client,
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            fake_pipeline_run,
        ),
    ):
        config_entry = await setup_config_entry(hass)
        satellite_device: SatelliteDevice = hass.data[wyoming.DOMAIN][
            config_entry.entry_id
        ].device

        async with asyncio.timeout(1):
            await client.connect_event.wait()
            await client.run_satellite_event.wait()

        # Wait for the pipeline to start running.
        async with asyncio.timeout(1):
            await pipeline_running.wait()

        assert captured_callback is not None

        # Changing noise suppression must stop the running pipeline.
        satellite_device.set_noise_suppression_level(1)

        async with asyncio.timeout(1):
            await pipeline_finished.wait()

        # Shut the satellite down.
        await hass.config_entries.async_unload(config_entry.entry_id)
        await hass.async_block_till_done()
async def test_invalid_stages(hass: HomeAssistant) -> None:
    """Test error when providing invalid pipeline stages.

    ``_run_pipeline_once`` is wrapped so that, when the satellite requests a
    run, the real method is called twice with stages the test marks as bad.
    The first call uses an INTENT start stage and the second an INTENT end
    stage. Each call is expected to raise ``ValueError``, and each raise sets
    its own event.
    """
    assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})

    events = [
        RunPipeline(
            start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
        ).event(),
    ]

    original_run_pipeline_once = WyomingAssistSatellite._run_pipeline_once
    start_stage_event = asyncio.Event()
    end_stage_event = asyncio.Event()

    def _run_pipeline_once(self, run_pipeline, wake_word_phrase=None):
        # Only the stages are altered. wake_word_phrase is forwarded unchanged,
        # so the real method sees exactly what the satellite passed in.

        # Set bad start stage
        run_pipeline.start_stage = PipelineStage.INTENT
        run_pipeline.end_stage = PipelineStage.TTS

        try:
            original_run_pipeline_once(self, run_pipeline, wake_word_phrase)
        except ValueError:
            start_stage_event.set()

        # Set bad end stage
        run_pipeline.start_stage = PipelineStage.WAKE
        run_pipeline.end_stage = PipelineStage.INTENT

        try:
            original_run_pipeline_once(self, run_pipeline, wake_word_phrase)
        except ValueError:
            end_stage_event.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            SatelliteAsyncTcpClient(events),
        ) as mock_client,
        patch(
            "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite._run_pipeline_once",
            _run_pipeline_once,
        ),
    ):
        entry = await setup_config_entry(hass)

        async with asyncio.timeout(1):
            await mock_client.connect_event.wait()
            await mock_client.run_satellite_event.wait()

        # Both bad stage combinations must have been rejected.
        async with asyncio.timeout(1):
            await start_stage_event.wait()
            await end_stage_event.wait()

        # Stop the satellite
        await hass.config_entries.async_unload(entry.entry_id)
        await hass.async_block_till_done()
async def test_client_stops_pipeline(hass: HomeAssistant) -> None:
    """Test that an AudioStop sent by the satellite stops the running pipeline."""
    assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})

    script = [
        RunPipeline(
            start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
        ).event(),
    ]

    captured_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = None
    pipeline_running = asyncio.Event()
    pipeline_finished = asyncio.Event()

    async def fake_pipeline_run(
        hass: HomeAssistant,
        context,
        event_callback,
        stt_metadata,
        stt_stream,
        **kwargs,
    ) -> None:
        nonlocal captured_callback
        captured_callback = event_callback
        pipeline_running.set()

        # Consume audio until the stream ends.
        async for _chunk in stt_stream:
            pass

        pipeline_finished.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            SatelliteAsyncTcpClient(script),
        ) as client,
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            fake_pipeline_run,
        ),
    ):
        config_entry = await setup_config_entry(hass)

        async with asyncio.timeout(1):
            await client.connect_event.wait()
            await client.run_satellite_event.wait()

        # Wait for the pipeline to start running.
        async with asyncio.timeout(1):
            await pipeline_running.wait()

        assert captured_callback is not None

        # The satellite itself signals the end of audio.
        client.inject_event(AudioStop().event())

        async with asyncio.timeout(1):
            await pipeline_finished.wait()

        # Shut the satellite down.
        await hass.config_entries.async_unload(config_entry.entry_id)
        await hass.async_block_till_done()
async def test_wake_word_phrase(hass: HomeAssistant) -> None:
    """Test that the wake word phrase from info is passed to the pipeline."""
    # Simulate local wake word detection: info, a detection, then a run request.
    satellite_info = Info(satellite=SATELLITE_INFO.satellite, wake=WAKE_WORD_INFO.wake)
    script = [
        satellite_info.event(),
        Detection(name="Test Model").event(),
        RunPipeline(
            start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
        ).event(),
    ]
    pipeline_started = asyncio.Event()

    def _fake_pipeline_run(*args: Any, **kwargs: Any) -> None:
        pipeline_started.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            SatelliteAsyncTcpClient(script),
        ),
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            wraps=_fake_pipeline_run,
        ) as pipeline_mock,
    ):
        await setup_config_entry(hass)

        async with asyncio.timeout(1):
            await pipeline_started.wait()

        # The pipeline receives the wake word phrase for deconfliction.
        pipeline_mock.assert_called_once()
        phrase = pipeline_mock.call_args.kwargs.get("wake_word_phrase")
        assert phrase == "Test Phrase"
async def test_timers(hass: HomeAssistant) -> None:
    """Test timer events.

    Timer intents are handled for the satellite's device, and the test checks
    that each one is forwarded to the fake client as the matching Wyoming timer
    event: start, pause/resume/add/remove (all TimerUpdated), cancel, and
    finish (a new timer decreased by its full duration).
    """
    assert await async_setup_component(hass, "intent", {})

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            SatelliteAsyncTcpClient([]),
        ) as mock_client,
    ):
        entry = await setup_config_entry(hass)
        device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device

        async with asyncio.timeout(1):
            await mock_client.connect_event.wait()
            await mock_client.run_satellite_event.wait()

        # Start timer
        result = await intent_helper.async_handle(
            hass,
            "test",
            intent_helper.INTENT_START_TIMER,
            {
                "name": {"value": "test timer"},
                "hours": {"value": 1},
                "minutes": {"value": 2},
                "seconds": {"value": 3},
            },
            device_id=device.device_id,
        )

        assert result.response_type == intent_helper.IntentResponseType.ACTION_DONE
        async with asyncio.timeout(1):
            await mock_client.timer_started_event.wait()

        # Keep the started timer; later events are matched against its id.
        timer_started = mock_client.timer_started
        assert timer_started is not None
        assert timer_started.id
        assert timer_started.name == "test timer"
        assert timer_started.start_hours == 1
        assert timer_started.start_minutes == 2
        assert timer_started.start_seconds == 3
        assert timer_started.total_seconds == (1 * 60 * 60) + (2 * 60) + 3

        # Pause
        # Each update step clears the shared TimerUpdated flag before acting.
        mock_client.timer_updated_event.clear()
        result = await intent_helper.async_handle(
            hass,
            "test",
            intent_helper.INTENT_PAUSE_TIMER,
            {},
            device_id=device.device_id,
        )

        assert result.response_type == intent_helper.IntentResponseType.ACTION_DONE
        async with asyncio.timeout(1):
            await mock_client.timer_updated_event.wait()

        timer_updated = mock_client.timer_updated
        assert timer_updated is not None
        assert timer_updated.id == timer_started.id
        assert not timer_updated.is_active

        # Resume
        mock_client.timer_updated_event.clear()
        result = await intent_helper.async_handle(
            hass,
            "test",
            intent_helper.INTENT_UNPAUSE_TIMER,
            {},
            device_id=device.device_id,
        )

        assert result.response_type == intent_helper.IntentResponseType.ACTION_DONE
        async with asyncio.timeout(1):
            await mock_client.timer_updated_event.wait()

        timer_updated = mock_client.timer_updated
        assert timer_updated is not None
        assert timer_updated.id == timer_started.id
        assert timer_updated.is_active

        # Add time
        mock_client.timer_updated_event.clear()
        result = await intent_helper.async_handle(
            hass,
            "test",
            intent_helper.INTENT_INCREASE_TIMER,
            {
                "hours": {"value": 2},
                "minutes": {"value": 3},
                "seconds": {"value": 4},
            },
            device_id=device.device_id,
        )

        assert result.response_type == intent_helper.IntentResponseType.ACTION_DONE
        async with asyncio.timeout(1):
            await mock_client.timer_updated_event.wait()

        timer_updated = mock_client.timer_updated
        assert timer_updated is not None
        assert timer_updated.id == timer_started.id
        assert timer_updated.total_seconds > timer_started.total_seconds

        # Remove time
        mock_client.timer_updated_event.clear()
        result = await intent_helper.async_handle(
            hass,
            "test",
            intent_helper.INTENT_DECREASE_TIMER,
            {
                "hours": {"value": 2},
                "minutes": {"value": 3},
                "seconds": {"value": 5},  # remove 1 extra second
            },
            device_id=device.device_id,
        )

        assert result.response_type == intent_helper.IntentResponseType.ACTION_DONE
        async with asyncio.timeout(1):
            await mock_client.timer_updated_event.wait()

        # One second more was removed than added, so the total is now below start.
        timer_updated = mock_client.timer_updated
        assert timer_updated is not None
        assert timer_updated.id == timer_started.id
        assert timer_updated.total_seconds < timer_started.total_seconds

        # Cancel
        result = await intent_helper.async_handle(
            hass,
            "test",
            intent_helper.INTENT_CANCEL_TIMER,
            {},
            device_id=device.device_id,
        )

        assert result.response_type == intent_helper.IntentResponseType.ACTION_DONE
        async with asyncio.timeout(1):
            await mock_client.timer_cancelled_event.wait()

        timer_cancelled = mock_client.timer_cancelled
        assert timer_cancelled is not None
        assert timer_cancelled.id == timer_started.id

        # Start a new timer
        mock_client.timer_started_event.clear()
        result = await intent_helper.async_handle(
            hass,
            "test",
            intent_helper.INTENT_START_TIMER,
            {
                "name": {"value": "test timer"},
                "minutes": {"value": 1},
            },
            device_id=device.device_id,
        )

        assert result.response_type == intent_helper.IntentResponseType.ACTION_DONE
        async with asyncio.timeout(1):
            await mock_client.timer_started_event.wait()

        timer_started = mock_client.timer_started
        assert timer_started is not None

        # Finished
        # Decreasing the 1-minute timer by 1 minute is expected to finish it.
        result = await intent_helper.async_handle(
            hass,
            "test",
            intent_helper.INTENT_DECREASE_TIMER,
            {
                "minutes": {"value": 1},  # force finish
            },
            device_id=device.device_id,
        )

        assert result.response_type == intent_helper.IntentResponseType.ACTION_DONE
        async with asyncio.timeout(1):
            await mock_client.timer_finished_event.wait()

        timer_finished = mock_client.timer_finished
        assert timer_finished is not None
        assert timer_finished.id == timer_started.id