mirror of https://github.com/home-assistant/core
1286 lines
43 KiB
Python
1286 lines
43 KiB
Python
"""Test Wyoming satellite."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from collections.abc import Callable
|
|
import io
|
|
from typing import Any
|
|
from unittest.mock import patch
|
|
import wave
|
|
|
|
from wyoming.asr import Transcribe, Transcript
|
|
from wyoming.audio import AudioChunk, AudioStart, AudioStop
|
|
from wyoming.error import Error
|
|
from wyoming.event import Event
|
|
from wyoming.info import Info
|
|
from wyoming.ping import Ping, Pong
|
|
from wyoming.pipeline import PipelineStage, RunPipeline
|
|
from wyoming.satellite import RunSatellite
|
|
from wyoming.timer import TimerCancelled, TimerFinished, TimerStarted, TimerUpdated
|
|
from wyoming.tts import Synthesize
|
|
from wyoming.vad import VoiceStarted, VoiceStopped
|
|
from wyoming.wake import Detect, Detection
|
|
|
|
from homeassistant.components import assist_pipeline, wyoming
|
|
from homeassistant.components.wyoming.assist_satellite import WyomingAssistSatellite
|
|
from homeassistant.components.wyoming.devices import SatelliteDevice
|
|
from homeassistant.const import STATE_ON
|
|
from homeassistant.core import HomeAssistant, State
|
|
from homeassistant.helpers import intent as intent_helper
|
|
from homeassistant.setup import async_setup_component
|
|
|
|
from . import SATELLITE_INFO, WAKE_WORD_INFO, MockAsyncTcpClient
|
|
|
|
from tests.common import MockConfigEntry
|
|
|
|
|
|
async def setup_config_entry(hass: HomeAssistant) -> MockConfigEntry:
    """Create, register and set up a mock Wyoming satellite config entry.

    This lives here as a plain helper, separate from the
    satellite_config_entry method in conftest.py, so a test can install its
    patches before setup runs the satellite task.
    """
    connection = {"host": "1.2.3.4", "port": 1234}
    config_entry = MockConfigEntry(
        domain="wyoming",
        data=connection,
        title="Test Satellite",
    )
    config_entry.add_to_hass(hass)

    # Run setup, then wait on hass.async_block_till_done() before returning.
    await hass.config_entries.async_setup(config_entry.entry_id)
    await hass.async_block_till_done()
    return config_entry
|
|
|
|
|
|
def get_test_wav() -> bytes:
    """Return an in-memory WAV file (22050 Hz, 16-bit, mono) holding b"123"."""
    buffer = io.BytesIO()
    try:
        with wave.open(buffer, "wb") as writer:
            writer.setnchannels(1)
            writer.setsampwidth(2)
            writer.setframerate(22050)

            # Audio payload written to the file
            writer.writeframes(b"123")

        # Read the bytes only after the wave writer has been closed, but
        # before the buffer itself is closed.
        return buffer.getvalue()
    finally:
        buffer.close()
|
|
|
|
|
|
class SatelliteAsyncTcpClient(MockAsyncTcpClient):
    """Mock TCP client that records the events written to it.

    Each event type recognised by write_event has an asyncio.Event flag that
    a test can await. Event types whose message is parsed also keep the most
    recently written message in a matching attribute.
    """

    # (event class, attribute that receives the parsed message or None,
    # name of the asyncio.Event flag to set). Entries are tried in order and
    # the first class whose is_type() accepts the event is used.
    _WRITE_DISPATCH = (
        (RunSatellite, None, "run_satellite_event"),
        (Detect, None, "detect_event"),
        (Detection, "detection", "detection_event"),
        (Transcribe, "transcribe", "transcribe_event"),
        (VoiceStarted, "voice_started", "voice_started_event"),
        (VoiceStopped, "voice_stopped", "voice_stopped_event"),
        (Transcript, "transcript", "transcript_event"),
        (Synthesize, "synthesize", "synthesize_event"),
        (AudioStart, None, "tts_audio_start_event"),
        (AudioChunk, "tts_audio_chunk", "tts_audio_chunk_event"),
        (AudioStop, None, "tts_audio_stop_event"),
        (Error, "error", "error_event"),
        (Pong, "pong", "pong_event"),
        (Ping, "ping", "ping_event"),
        (TimerStarted, "timer_started", "timer_started_event"),
        (TimerUpdated, "timer_updated", "timer_updated_event"),
        (TimerCancelled, "timer_cancelled", "timer_cancelled_event"),
        (TimerFinished, "timer_finished", "timer_finished_event"),
    )

    def __init__(self, responses: list[Event]) -> None:
        """Set up the flags and message slots used by write_event."""
        super().__init__(responses)

        # Flags without an associated stored message
        self.connect_event = asyncio.Event()
        self.run_satellite_event = asyncio.Event()
        self.detect_event = asyncio.Event()

        # Wake word detection
        self.detection: Detection | None = None
        self.detection_event = asyncio.Event()

        # Speech-to-text
        self.transcribe: Transcribe | None = None
        self.transcribe_event = asyncio.Event()
        self.voice_started: VoiceStarted | None = None
        self.voice_started_event = asyncio.Event()
        self.voice_stopped: VoiceStopped | None = None
        self.voice_stopped_event = asyncio.Event()
        self.transcript: Transcript | None = None
        self.transcript_event = asyncio.Event()

        # Text-to-speech
        self.synthesize: Synthesize | None = None
        self.synthesize_event = asyncio.Event()
        self.tts_audio_chunk: AudioChunk | None = None
        self.tts_audio_start_event = asyncio.Event()
        self.tts_audio_chunk_event = asyncio.Event()
        self.tts_audio_stop_event = asyncio.Event()

        # Errors
        self.error: Error | None = None
        self.error_event = asyncio.Event()

        # Ping/pong
        self.pong: Pong | None = None
        self.pong_event = asyncio.Event()
        self.ping: Ping | None = None
        self.ping_event = asyncio.Event()

        # Timers
        self.timer_started: TimerStarted | None = None
        self.timer_started_event = asyncio.Event()
        self.timer_updated: TimerUpdated | None = None
        self.timer_updated_event = asyncio.Event()
        self.timer_cancelled: TimerCancelled | None = None
        self.timer_cancelled_event = asyncio.Event()
        self.timer_finished: TimerFinished | None = None
        self.timer_finished_event = asyncio.Event()

        # Returned by read_event once the base class has no event to offer
        filler = AudioChunk(rate=16000, width=2, channels=1, audio=b"chunk")
        self._mic_audio_chunk = filler.event()

    async def connect(self) -> None:
        """Flag that a connection was requested."""
        self.connect_event.set()

    async def write_event(self, event: Event):
        """Record a written event and set the flag for its type.

        Events matching none of the known types are ignored.
        """
        for event_cls, message_attr, flag_attr in self._WRITE_DISPATCH:
            if not event_cls.is_type(event.type):
                continue

            # Store the parsed message before setting the flag
            if message_attr is not None:
                setattr(self, message_attr, event_cls.from_event(event))
            getattr(self, flag_attr).set()
            return

    async def read_event(self) -> Event | None:
        """Return the next event from the base class, or a filler audio chunk."""
        received = await super().read_event()
        if received:
            return received

        # Hand out an audio chunk in place of a falsy result
        return self._mic_audio_chunk

    def inject_event(self, event: Event) -> None:
        """Place an event at the front of the pending responses."""
        self.responses = [event] + list(self.responses)
|
|
|
|
|
|
async def test_satellite_pipeline(hass: HomeAssistant) -> None:
    """Walk a satellite through one complete pipeline run.

    The pipeline entry point is replaced by a stub that captures the event
    callback. The test feeds pipeline events through that callback and
    checks, step by step, which Wyoming events the mock client was sent.
    """
    assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})

    run_request = RunPipeline(
        start_stage=PipelineStage.WAKE,
        end_stage=PipelineStage.TTS,
        restart_on_end=True,
    )
    client = SatelliteAsyncTcpClient([run_request.event()])

    captured_kwargs: dict[str, Any] = {}
    captured_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = None
    pipeline_started = asyncio.Event()
    got_audio = asyncio.Event()

    async def fake_pipeline_from_audio_stream(
        hass: HomeAssistant,
        context,
        event_callback,
        stt_metadata,
        stt_stream,
        **kwargs,
    ) -> None:
        """Record the call, then read audio until a non-empty chunk arrives."""
        nonlocal captured_kwargs, captured_callback
        captured_kwargs = kwargs
        captured_callback = event_callback

        pipeline_started.set()
        async for audio in stt_stream:
            if not audio:
                continue
            got_audio.set()
            break

    async def wait_for(*flags: asyncio.Event) -> None:
        """Await each flag in turn under one shared 1 second timeout."""
        async with asyncio.timeout(1):
            for flag in flags:
                await flag.wait()

    def fire(*event_args: Any) -> None:
        """Pass a PipelineEvent built from event_args to the captured callback."""
        assert captured_callback is not None
        captured_callback(assist_pipeline.PipelineEvent(*event_args))

    event_types = assist_pipeline.PipelineEventType

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            client,
        ),
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            fake_pipeline_from_audio_stream,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio",
            return_value=("wav", get_test_wav()),
        ),
        patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0),
    ):
        entry = await setup_config_entry(hass)
        device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device
        assert device is not None

        await wait_for(client.connect_event, client.run_satellite_event)
        await wait_for(pipeline_started)

        # Clear the flag so the automatic restart can be detected at the end
        pipeline_started.clear()

        assert captured_callback is not None
        assert captured_kwargs.get("device_id") == device.device_id

        # A ping from the satellite must be answered by a pong with equal text
        client.inject_event(Ping("test-ping").event())
        await wait_for(client.pong_event)

        assert client.pong is not None
        assert client.pong.text == "test-ping"

        # A first ping should have been written to the client
        await wait_for(client.ping_event)

        assert client.ping is not None

        # Forget that ping and answer with a pong; a second ping is expected
        # before the test finishes.
        client.ping_event.clear()
        client.ping = None
        client.inject_event(Pong().event())

        # Wake word detection begins
        fire(event_types.WAKE_WORD_START)
        await wait_for(client.detect_event)

        assert not device.is_active
        assert not device.is_muted

        # Feed some audio
        client.inject_event(
            AudioChunk(rate=16000, width=2, channels=1, audio=bytes(1024)).event()
        )

        # Wake word detection ends with a result
        fire(
            event_types.WAKE_WORD_END,
            {"wake_word_output": {"wake_word_id": "test_wake_word"}},
        )
        await wait_for(client.detection_event)

        assert client.detection is not None
        assert client.detection.name == "test_wake_word"

        # Speech-to-text begins
        fire(event_types.STT_START, {"metadata": {"language": "en"}})
        await wait_for(client.transcribe_event)

        assert client.transcribe is not None
        assert client.transcribe.language == "en"

        # The device should report itself active from here on
        assert device.is_active

        # Feed some audio
        client.inject_event(
            AudioChunk(rate=16000, width=2, channels=1, audio=bytes(1024)).event()
        )

        # Voice activity starts
        fire(event_types.STT_VAD_START, {"timestamp": 1234})
        await wait_for(client.voice_started_event)

        assert client.voice_started is not None
        assert client.voice_started.timestamp == 1234

        # Voice activity ends
        fire(event_types.STT_VAD_END, {"timestamp": 5678})
        await wait_for(client.voice_stopped_event)

        assert client.voice_stopped is not None
        assert client.voice_stopped.timestamp == 5678

        # Speech-to-text result
        fire(event_types.STT_END, {"stt_output": {"text": "test transcript"}})
        await wait_for(client.transcript_event)

        assert client.transcript is not None
        assert client.transcript.text == "test transcript"

        # Text-to-speech input
        fire(
            event_types.TTS_START,
            {
                "tts_input": "test text to speak",
                "voice": "test voice",
            },
        )
        await wait_for(client.synthesize_event)

        assert client.synthesize is not None
        assert client.synthesize.text == "test text to speak"
        assert client.synthesize.voice is not None
        assert client.synthesize.voice.name == "test voice"

        # Text-to-speech output media
        fire(event_types.TTS_END, {"tts_output": {"media_id": "test media id"}})
        await wait_for(
            client.tts_audio_start_event,
            client.tts_audio_chunk_event,
            client.tts_audio_stop_event,
        )

        # The streamed chunk must match the test WAV
        assert client.tts_audio_chunk is not None
        assert client.tts_audio_chunk.rate == 22050
        assert client.tts_audio_chunk.width == 2
        assert client.tts_audio_chunk.channels == 1
        assert client.tts_audio_chunk.audio == b"123"

        # End of the pipeline run
        fire(event_types.RUN_END)
        assert not device.is_active

        # A second ping should have been written by now
        await wait_for(client.ping_event)

        assert client.ping is not None

        # The pipeline is expected to start again on its own
        await wait_for(pipeline_started)

        # Shut the satellite down
        await hass.config_entries.async_unload(entry.entry_id)
        await hass.async_block_till_done()
|
|
|
|
|
|
async def test_satellite_muted(hass: HomeAssistant) -> None:
    """Check that on_muted is called when the mute switch is restored as on."""
    muted_callback_done = asyncio.Event()
    real_on_muted = WyomingAssistSatellite.on_muted

    async def patched_on_muted(self):
        """Run the real on_muted, stop the satellite, then release the test."""
        # Set the changed flag before delegating to the real implementation
        self._muted_changed_event.set()
        await real_on_muted(self)

        # Make sure the satellite does not keep running
        self.is_running = False

        # Unmute and let the test continue
        self.device.set_is_muted(False)
        muted_callback_done.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            SatelliteAsyncTcpClient([]),
        ),
        # Report the mute switch's last state as "on"
        patch(
            "homeassistant.components.wyoming.switch.WyomingSatelliteMuteSwitch.async_get_last_state",
            return_value=State("switch.test_mute", STATE_ON),
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_muted",
            patched_on_muted,
        ),
    ):
        config_entry = await setup_config_entry(hass)
        async with asyncio.timeout(1):
            await muted_callback_done.wait()

        # Shut the satellite down
        await hass.config_entries.async_unload(config_entry.entry_id)
        await hass.async_block_till_done()
|
|
|
|
|
|
async def test_satellite_restart(hass: HomeAssistant) -> None:
    """Check that on_restart is called when _connect_and_loop raises."""
    restart_seen = asyncio.Event()
    real_on_restart = WyomingAssistSatellite.on_restart

    async def patched_on_restart(self):
        """Run the real on_restart, stop the satellite and release the test."""
        await real_on_restart(self)
        self.stop_satellite()
        restart_seen.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        # Make the connect/loop step fail with an unexpected error
        patch(
            "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite._connect_and_loop",
            side_effect=RuntimeError(),
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_restart",
            patched_on_restart,
        ),
        patch("homeassistant.components.wyoming.assist_satellite._RESTART_SECONDS", 0),
    ):
        await setup_config_entry(hass)
        async with asyncio.timeout(1):
            await restart_seen.wait()
|
|
|
|
|
|
async def test_satellite_reconnect(hass: HomeAssistant) -> None:
    """Check that on_reconnect is called repeatedly while connecting is refused."""
    reconnect_count = 0
    reconnected_twice = asyncio.Event()
    satellite_stopped = asyncio.Event()

    real_on_reconnect = WyomingAssistSatellite.on_reconnect

    async def patched_on_reconnect(self):
        """Count reconnect callbacks and stop the satellite at the second."""
        nonlocal reconnect_count
        await real_on_reconnect(self)

        reconnect_count += 1
        if reconnect_count < 2:
            return

        reconnected_twice.set()
        self.stop_satellite()

    async def patched_on_stopped(self):
        """Flag that the satellite reported being stopped."""
        satellite_stopped.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        # Every connection attempt is refused
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient.connect",
            side_effect=ConnectionRefusedError(),
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_reconnect",
            patched_on_reconnect,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_stopped",
            patched_on_stopped,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite._RECONNECT_SECONDS", 0
        ),
    ):
        await setup_config_entry(hass)
        async with asyncio.timeout(1):
            await reconnected_twice.wait()
            await satellite_stopped.wait()
|
|
|
|
|
|
async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None:
    """Check that no pipeline runs when the client has no events to deliver."""
    restart_seen = asyncio.Event()

    async def patched_on_restart(self):
        """Stop the satellite and release the test."""
        self.stop_satellite()
        restart_seen.set()

    # An empty response list means no RunPipeline event is ever delivered
    silent_client = MockAsyncTcpClient([])

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            silent_client,
        ),
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
        ) as mock_run_pipeline,
        patch(
            "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_restart",
            patched_on_restart,
        ),
    ):
        await setup_config_entry(hass)
        async with asyncio.timeout(1):
            await restart_seen.wait()

        # The pipeline entry point must not have been invoked
        mock_run_pipeline.assert_not_called()
|
|
|
|
|
|
async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None:
    """Check cleanup when the client runs out of events after RunPipeline."""
    # Only the RunPipeline request is queued; no audio chunks follow it
    run_request = RunPipeline(
        start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
    )
    client = MockAsyncTcpClient([run_request.event()])

    restart_seen = asyncio.Event()
    stopped_seen = asyncio.Event()

    async def patched_on_restart(self):
        """Force the active flag on, stop the satellite, release the test."""
        # Simulate the sensor being stuck in the "on" state
        self.device.is_active = True
        self.stop_satellite()
        restart_seen.set()

    async def patched_on_stopped(self):
        """Flag that the satellite reported being stopped."""
        stopped_seen.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            client,
        ),
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
        ) as mock_run_pipeline,
        patch(
            "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_restart",
            patched_on_restart,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite.on_stopped",
            patched_on_stopped,
        ),
    ):
        config_entry = await setup_config_entry(hass)
        device: SatelliteDevice = hass.data[wyoming.DOMAIN][
            config_entry.entry_id
        ].device

        async with asyncio.timeout(1):
            await restart_seen.wait()
            await stopped_seen.wait()

        # Exactly one pipeline run is expected
        mock_run_pipeline.assert_called_once()

        # The active flag forced on above must have been cleared again
        assert not device.is_active
|
|
|
|
|
|
async def test_satellite_error_during_pipeline(hass: HomeAssistant) -> None:
    """Check that a pipeline ERROR event is written to the client as an Error."""
    # Only the RunPipeline request is queued; no audio chunks follow it
    run_request = RunPipeline(
        start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
    )
    client = SatelliteAsyncTcpClient([run_request.event()])

    pipeline_invoked = asyncio.Event()

    def fake_pipeline(*args: Any, **kwargs: Any) -> None:
        """Flag that the pipeline entry point was called."""
        pipeline_invoked.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            client,
        ),
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            wraps=fake_pipeline,
        ) as mock_run_pipeline,
    ):
        await setup_config_entry(hass)

        async with asyncio.timeout(1):
            await pipeline_invoked.wait()
            await client.connect_event.wait()
            await client.run_satellite_event.wait()

        mock_run_pipeline.assert_called_once()

        # Report an error through the callback the pipeline was given
        report = mock_run_pipeline.call_args.kwargs["event_callback"]
        failure = assist_pipeline.PipelineEvent(
            assist_pipeline.PipelineEventType.ERROR,
            {"code": "test code", "message": "test message"},
        )
        report(failure)

        async with asyncio.timeout(1):
            await client.error_event.wait()

        assert client.error is not None
        assert client.error.text == "test message"
        assert client.error.code == "test code"
|
|
|
|
|
|
async def test_tts_not_wav(hass: HomeAssistant) -> None:
    """Check that _stream_tts raises ValueError for audio reported as mp3."""
    assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})

    real_stream_tts = WyomingAssistSatellite._stream_tts
    value_error_seen = asyncio.Event()

    async def guarded_stream_tts(self, media_id):
        """Call the real _stream_tts and flag a ValueError instead of raising."""
        try:
            await real_stream_tts(self, media_id)
        except ValueError:
            value_error_seen.set()

    pipeline_invoked = asyncio.Event()

    def fake_pipeline(*args: Any, **kwargs: Any) -> None:
        """Flag that the pipeline entry point was called."""
        pipeline_invoked.set()

    run_request = RunPipeline(
        start_stage=PipelineStage.TTS, end_stage=PipelineStage.TTS
    )
    client = SatelliteAsyncTcpClient([run_request.event()])

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            client,
        ),
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            wraps=fake_pipeline,
        ) as mock_run_pipeline,
        # Report the media as mp3 rather than wav
        patch(
            "homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio",
            return_value=("mp3", bytes(1)),
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite._stream_tts",
            guarded_stream_tts,
        ),
    ):
        config_entry = await setup_config_entry(hass)

        async with asyncio.timeout(1):
            await pipeline_invoked.wait()
            await client.connect_event.wait()
            await client.run_satellite_event.wait()

        mock_run_pipeline.assert_called_once()
        send_event = mock_run_pipeline.call_args.kwargs["event_callback"]

        # Text-to-speech input
        tts_start = assist_pipeline.PipelineEvent(
            assist_pipeline.PipelineEventType.TTS_START,
            {
                "tts_input": "test text to speak",
                "voice": "test voice",
            },
        )
        send_event(tts_start)
        async with asyncio.timeout(1):
            await client.synthesize_event.wait()

        # Text-to-speech output media
        tts_end = assist_pipeline.PipelineEvent(
            assist_pipeline.PipelineEventType.TTS_END,
            {"tts_output": {"media_id": "test media id"}},
        )
        send_event(tts_end)

        # Streaming must fail because the audio is not WAV
        async with asyncio.timeout(1):
            await value_error_seen.wait()

        # Shut the satellite down
        await hass.config_entries.async_unload(config_entry.entry_id)
        await hass.async_block_till_done()
|
|
|
|
|
|
async def test_pipeline_changed(hass: HomeAssistant) -> None:
    """Check that selecting another pipeline ends the running pipeline's audio."""
    assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})

    run_request = RunPipeline(
        start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
    )
    client = SatelliteAsyncTcpClient([run_request.event()])

    captured_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = None
    pipeline_started = asyncio.Event()
    pipeline_finished = asyncio.Event()

    async def fake_pipeline_from_audio_stream(
        hass: HomeAssistant,
        context,
        event_callback,
        stt_metadata,
        stt_stream,
        **kwargs,
    ) -> None:
        """Drain the audio stream, flagging both the start and the end."""
        nonlocal captured_callback
        captured_callback = event_callback

        pipeline_started.set()
        async for _audio in stt_stream:
            pass

        # Only reached once the audio stream has ended
        pipeline_finished.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            client,
        ),
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            fake_pipeline_from_audio_stream,
        ),
    ):
        config_entry = await setup_config_entry(hass)
        device: SatelliteDevice = hass.data[wyoming.DOMAIN][
            config_entry.entry_id
        ].device

        async with asyncio.timeout(1):
            await client.connect_event.wait()
            await client.run_satellite_event.wait()

        # Wait for the pipeline to begin
        async with asyncio.timeout(1):
            await pipeline_started.wait()

        assert captured_callback is not None

        # Select a different pipeline on the device
        device.set_pipeline_name("different pipeline")

        # The running pipeline's audio stream should now end
        async with asyncio.timeout(1):
            await pipeline_finished.wait()

        # Shut the satellite down
        await hass.config_entries.async_unload(config_entry.entry_id)
        await hass.async_block_till_done()
|
|
|
|
|
|
async def test_audio_settings_changed(hass: HomeAssistant) -> None:
    """Check that changing an audio setting ends the running pipeline's audio."""
    assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})

    run_request = RunPipeline(
        start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
    )
    client = SatelliteAsyncTcpClient([run_request.event()])

    captured_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = None
    pipeline_started = asyncio.Event()
    pipeline_finished = asyncio.Event()

    async def fake_pipeline_from_audio_stream(
        hass: HomeAssistant,
        context,
        event_callback,
        stt_metadata,
        stt_stream,
        **kwargs,
    ) -> None:
        """Drain the audio stream, flagging both the start and the end."""
        nonlocal captured_callback
        captured_callback = event_callback

        pipeline_started.set()
        async for _audio in stt_stream:
            pass

        # Only reached once the audio stream has ended
        pipeline_finished.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            client,
        ),
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            fake_pipeline_from_audio_stream,
        ),
    ):
        config_entry = await setup_config_entry(hass)
        device: SatelliteDevice = hass.data[wyoming.DOMAIN][
            config_entry.entry_id
        ].device

        async with asyncio.timeout(1):
            await client.connect_event.wait()
            await client.run_satellite_event.wait()

        # Wait for the pipeline to begin
        async with asyncio.timeout(1):
            await pipeline_started.wait()

        assert captured_callback is not None

        # Alter one of the device's audio settings
        device.set_noise_suppression_level(1)

        # The running pipeline's audio stream should now end
        async with asyncio.timeout(1):
            await pipeline_finished.wait()

        # Shut the satellite down
        await hass.config_entries.async_unload(config_entry.entry_id)
        await hass.async_block_till_done()
|
|
|
|
|
|
async def test_invalid_stages(hass: HomeAssistant) -> None:
    """Test error when providing invalid pipeline stages.

    _run_pipeline_once is replaced by a wrapper that calls the original
    method twice, first with a bad start stage and then with a bad end
    stage, and records each ValueError in its own asyncio.Event.
    """
    assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})

    events = [
        RunPipeline(
            start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
        ).event(),
    ]

    original_run_pipeline_once = WyomingAssistSatellite._run_pipeline_once
    start_stage_event = asyncio.Event()
    end_stage_event = asyncio.Event()

    def _run_pipeline_once(self, run_pipeline, wake_word_phrase):
        """Call the original method with invalid stage combinations.

        wake_word_phrase is forwarded unchanged so the original method is
        invoked with the same arguments this wrapper received.
        """
        # Set bad start stage
        run_pipeline.start_stage = PipelineStage.INTENT
        run_pipeline.end_stage = PipelineStage.TTS

        try:
            original_run_pipeline_once(self, run_pipeline, wake_word_phrase)
        except ValueError:
            start_stage_event.set()

        # Set bad end stage
        run_pipeline.start_stage = PipelineStage.WAKE
        run_pipeline.end_stage = PipelineStage.INTENT

        try:
            original_run_pipeline_once(self, run_pipeline, wake_word_phrase)
        except ValueError:
            end_stage_event.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            SatelliteAsyncTcpClient(events),
        ) as mock_client,
        patch(
            "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite._run_pipeline_once",
            _run_pipeline_once,
        ),
    ):
        entry = await setup_config_entry(hass)

        async with asyncio.timeout(1):
            await mock_client.connect_event.wait()
            await mock_client.run_satellite_event.wait()

        # Both invalid combinations must have raised ValueError
        async with asyncio.timeout(1):
            await start_stage_event.wait()
            await end_stage_event.wait()

        # Stop the satellite
        await hass.config_entries.async_unload(entry.entry_id)
        await hass.async_block_till_done()
|
|
|
|
|
|
async def test_client_stops_pipeline(hass: HomeAssistant) -> None:
    """Check that an AudioStop from the client ends the running pipeline's audio."""
    assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})

    run_request = RunPipeline(
        start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
    )
    client = SatelliteAsyncTcpClient([run_request.event()])

    captured_callback: Callable[[assist_pipeline.PipelineEvent], None] | None = None
    pipeline_started = asyncio.Event()
    pipeline_finished = asyncio.Event()

    async def fake_pipeline_from_audio_stream(
        hass: HomeAssistant,
        context,
        event_callback,
        stt_metadata,
        stt_stream,
        **kwargs,
    ) -> None:
        """Drain the audio stream, flagging both the start and the end."""
        nonlocal captured_callback
        captured_callback = event_callback

        pipeline_started.set()
        async for _audio in stt_stream:
            pass

        # Only reached once the audio stream has ended
        pipeline_finished.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            client,
        ),
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            fake_pipeline_from_audio_stream,
        ),
    ):
        config_entry = await setup_config_entry(hass)

        async with asyncio.timeout(1):
            await client.connect_event.wait()
            await client.run_satellite_event.wait()

        # Wait for the pipeline to begin
        async with asyncio.timeout(1):
            await pipeline_started.wait()

        assert captured_callback is not None

        # The client announces the end of its audio
        client.inject_event(AudioStop().event())

        # The running pipeline's audio stream should now end
        async with asyncio.timeout(1):
            await pipeline_finished.wait()

        # Shut the satellite down
        await hass.config_entries.async_unload(config_entry.entry_id)
        await hass.async_block_till_done()
|
|
|
|
|
|
async def test_wake_word_phrase(hass: HomeAssistant) -> None:
    """Check that the wake word phrase from the info event reaches the pipeline."""
    satellite_events = [
        # Fake local wake word detection
        Info(satellite=SATELLITE_INFO.satellite, wake=WAKE_WORD_INFO.wake).event(),
        Detection(name="Test Model").event(),
        RunPipeline(
            start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
        ).event(),
    ]

    pipeline_invoked = asyncio.Event()

    def _signal_pipeline_invoked(*args: Any, **kwargs: Any) -> None:
        """Flag that the patched pipeline entry point was called."""
        pipeline_invoked.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            SatelliteAsyncTcpClient(satellite_events),
        ),
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            wraps=_signal_pipeline_invoked,
        ) as mock_run_pipeline,
    ):
        await setup_config_entry(hass)

        async with asyncio.timeout(1):
            await pipeline_invoked.wait()

        # The pipeline entry point is expected to be called exactly once, with
        # the wake word phrase passed along for deconfliction.
        mock_run_pipeline.assert_called_once()
        received_phrase = mock_run_pipeline.call_args.kwargs.get("wake_word_phrase")
        assert received_phrase == "Test Phrase"
|
|
|
|
|
|
async def test_timers(hass: HomeAssistant) -> None:
    """Drive each timer intent and check what the mock client records for it."""
    assert await async_setup_component(hass, "intent", {})

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            SatelliteAsyncTcpClient([]),
        ) as mock_client,
    ):
        entry = await setup_config_entry(hass)
        device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device

        async with asyncio.timeout(1):
            await mock_client.connect_event.wait()
            await mock_client.run_satellite_event.wait()

        async def wait_for(event: asyncio.Event) -> None:
            """Wait at most one second for ``event`` to be set."""
            async with asyncio.timeout(1):
                await event.wait()

        async def run_timer_intent(intent_type: str, slots: dict[str, Any]) -> None:
            """Handle a timer intent for the satellite device; it must succeed."""
            response = await intent_helper.async_handle(
                hass,
                "test",
                intent_type,
                slots,
                device_id=device.device_id,
            )
            assert (
                response.response_type == intent_helper.IntentResponseType.ACTION_DONE
            )

        async def expect_update(
            intent_type: str, slots: dict[str, Any], timer_id: str
        ) -> Any:
            """Run an intent and return the update the client records for it."""
            # Clear first so the wait below only sees the new update.
            mock_client.timer_updated_event.clear()
            await run_timer_intent(intent_type, slots)
            await wait_for(mock_client.timer_updated_event)

            update = mock_client.timer_updated
            assert update is not None
            assert update.id == timer_id
            return update

        # Start timer
        await run_timer_intent(
            intent_helper.INTENT_START_TIMER,
            {
                "name": {"value": "test timer"},
                "hours": {"value": 1},
                "minutes": {"value": 2},
                "seconds": {"value": 3},
            },
        )
        await wait_for(mock_client.timer_started_event)
        first_timer = mock_client.timer_started
        assert first_timer is not None
        assert first_timer.id
        assert first_timer.name == "test timer"
        assert first_timer.start_hours == 1
        assert first_timer.start_minutes == 2
        assert first_timer.start_seconds == 3
        assert first_timer.total_seconds == 1 * 3600 + 2 * 60 + 3

        # Pause
        paused = await expect_update(
            intent_helper.INTENT_PAUSE_TIMER, {}, first_timer.id
        )
        assert not paused.is_active

        # Resume
        resumed = await expect_update(
            intent_helper.INTENT_UNPAUSE_TIMER, {}, first_timer.id
        )
        assert resumed.is_active

        # Add time
        lengthened = await expect_update(
            intent_helper.INTENT_INCREASE_TIMER,
            {
                "hours": {"value": 2},
                "minutes": {"value": 3},
                "seconds": {"value": 4},
            },
            first_timer.id,
        )
        assert lengthened.total_seconds > first_timer.total_seconds

        # Remove time
        shortened = await expect_update(
            intent_helper.INTENT_DECREASE_TIMER,
            {
                "hours": {"value": 2},
                "minutes": {"value": 3},
                "seconds": {"value": 5},  # remove 1 extra second
            },
            first_timer.id,
        )
        assert shortened.total_seconds < first_timer.total_seconds

        # Cancel
        await run_timer_intent(intent_helper.INTENT_CANCEL_TIMER, {})
        await wait_for(mock_client.timer_cancelled_event)
        cancelled = mock_client.timer_cancelled
        assert cancelled is not None
        assert cancelled.id == first_timer.id

        # Start a new timer
        mock_client.timer_started_event.clear()
        await run_timer_intent(
            intent_helper.INTENT_START_TIMER,
            {
                "name": {"value": "test timer"},
                "minutes": {"value": 1},
            },
        )
        await wait_for(mock_client.timer_started_event)
        second_timer = mock_client.timer_started
        assert second_timer is not None

        # Finished
        await run_timer_intent(
            intent_helper.INTENT_DECREASE_TIMER,
            {
                "minutes": {"value": 1},  # force finish
            },
        )
        await wait_for(mock_client.timer_finished_event)
        finished = mock_client.timer_finished
        assert finished is not None
        assert finished.id == second_timer.id
|