# core/tests/components/conversation/test_session.py
#
# 416 lines
# 14 KiB
# Python

"""Test the conversation session."""
from collections.abc import Generator
from datetime import timedelta
from unittest.mock import Mock, patch
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components.conversation import ConversationInput, session
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import llm
from homeassistant.util import dt as dt_util
from tests.common import async_fire_time_changed
@pytest.fixture
def mock_conversation_input(hass: HomeAssistant) -> ConversationInput:
    """Provide an English ``ConversationInput`` that says "Hello".

    No conversation ID or device ID is set; individual tests mutate the
    returned object (e.g. ``conversation_id``) to exercise follow-up turns.
    """
    input_kwargs = {
        "text": "Hello",
        "context": Context(),
        "conversation_id": None,
        "agent_id": "mock-agent-id",
        "device_id": None,
        "language": "en",
    }
    return ConversationInput(**input_kwargs)
@pytest.fixture
def mock_ulid() -> Generator[Mock]:
    """Patch ``ulid_now`` so it always returns ``"mock-ulid"``.

    Yields the patched mock for the duration of the test.
    """
    ulid_patcher = patch("homeassistant.util.ulid.ulid_now", return_value="mock-ulid")
    with ulid_patcher as patched_ulid_now:
        yield patched_ulid_now
@pytest.mark.parametrize(
    ("start_id", "given_id"),
    [
        # No ID on the input: the (mocked) ULID generator supplies one
        (None, "mock-ulid"),
        # This ULID is not known as a session
        ("01JHXE0952TSJCFJZ869AW6HMD", "mock-ulid"),
        # A value that is not a ULID is kept as given
        ("not-a-ulid", "not-a-ulid"),
    ],
)
async def test_conversation_id(
    hass: HomeAssistant,
    mock_conversation_input: ConversationInput,
    mock_ulid: Mock,
    start_id: str | None,
    given_id: str,
) -> None:
    """Test which conversation ID a chat session ends up with."""
    mock_conversation_input.conversation_id = start_id
    session_cm = session.async_get_chat_session(hass, mock_conversation_input)
    async with session_cm as chat:
        resolved_id = chat.conversation_id
        assert resolved_id == given_id
async def test_cleanup(
    hass: HomeAssistant,
    mock_conversation_input: ConversationInput,
) -> None:
    """Test that a stored chat session is removed once it has timed out."""
    async with session.async_get_chat_session(
        hass, mock_conversation_input
    ) as chat_session:
        # A new session holds the system prompt and the user message
        assert len(chat_session.messages) == 2
        conversation_id = chat_session.conversation_id

    # Generate session entry.
    async with session.async_get_chat_session(
        hass, mock_conversation_input
    ) as chat_session:
        # Because we didn't add a message to the session in the last block,
        # the conversation was not persisted and we get a new ID
        assert chat_session.conversation_id != conversation_id
        conversation_id = chat_session.conversation_id
        # Adding a reply is what gets this session stored for later turns
        chat_session.async_add_message(
            session.ChatMessage(
                role="assistant",
                agent_id="mock-agent-id",
                content="Hey!",
            )
        )
        assert len(chat_session.messages) == 3

    # Reuse conversation ID to ensure we can chat with same session
    mock_conversation_input.conversation_id = conversation_id
    async with session.async_get_chat_session(
        hass, mock_conversation_input
    ) as chat_session:
        # Stored system/user/assistant messages plus the new user message
        assert len(chat_session.messages) == 4
        assert chat_session.conversation_id == conversation_id

    # Push last_updated one timeout into the future (this is NOT an expired
    # timestamp). The cleanup fired just after one timeout must therefore
    # keep the session, and only the cleanup fired after two timeouts should
    # find it expired.
    hass.data[session.DATA_CHAT_HISTORY][conversation_id].last_updated = (
        dt_util.utcnow() + session.CONVERSATION_TIMEOUT
    )

    async_fire_time_changed(
        hass, dt_util.utcnow() + session.CONVERSATION_TIMEOUT + timedelta(seconds=1)
    )

    # Should not be cleaned up, but it should have scheduled another cleanup
    mock_conversation_input.conversation_id = conversation_id
    async with session.async_get_chat_session(
        hass, mock_conversation_input
    ) as chat_session:
        assert len(chat_session.messages) == 4
        assert chat_session.conversation_id == conversation_id

    async_fire_time_changed(
        hass, dt_util.utcnow() + session.CONVERSATION_TIMEOUT * 2 + timedelta(seconds=1)
    )

    # It should be cleaned up now and we start a new conversation
    async with session.async_get_chat_session(
        hass, mock_conversation_input
    ) as chat_session:
        assert chat_session.conversation_id != conversation_id
        assert len(chat_session.messages) == 2
def test_chat_message() -> None:
    """Test that a ``native`` role message with ``native=None`` is rejected."""
    with pytest.raises(ValueError):
        session.ChatMessage(native=None, content="", agent_id=None, role="native")
async def test_add_message(
    hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None:
    """Test the validation applied when adding messages to a session."""
    async with session.async_get_chat_session(
        hass, mock_conversation_input
    ) as chat_session:
        # A new session holds the system prompt and the user message
        assert len(chat_session.messages) == 2

        # Adding a system message is rejected
        with pytest.raises(ValueError):
            chat_session.async_add_message(
                session.ChatMessage(role="system", agent_id=None, content="")
            )

        # No 2 user messages in a row
        assert chat_session.messages[1].role == "user"
        with pytest.raises(ValueError):
            chat_session.async_add_message(
                session.ChatMessage(role="user", agent_id=None, content="")
            )

        # No 2 assistant messages in a row: the first one is accepted...
        chat_session.async_add_message(
            session.ChatMessage(role="assistant", agent_id=None, content="")
        )
        assert len(chat_session.messages) == 3
        assert chat_session.messages[-1].role == "assistant"
        # ...and a second consecutive one is refused
        with pytest.raises(ValueError):
            chat_session.async_add_message(
                session.ChatMessage(role="assistant", agent_id=None, content="")
            )
async def test_message_filtering(
    hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None:
    """Test that ``async_get_messages`` drops native messages of other agents."""
    async with session.async_get_chat_session(
        hass, mock_conversation_input
    ) as chat:
        initial = chat.async_get_messages(agent_id=None)
        assert len(initial) == 2
        system_msg, user_msg = initial
        assert system_msg == session.ChatMessage(
            role="system",
            agent_id=None,
            content="",
        )
        assert user_msg == session.ChatMessage(
            role="user",
            agent_id=mock_conversation_input.agent_id,
            content=mock_conversation_input.text,
        )

        # Cannot add a second user message in a row
        duplicate_user = session.ChatMessage(
            role="user",
            agent_id="mock-agent-id",
            content="Hey!",
        )
        with pytest.raises(ValueError):
            chat.async_add_message(duplicate_user)

        for message in (
            session.ChatMessage(
                role="assistant",
                agent_id="mock-agent-id",
                content="Hey!",
                native="assistant-reply-native",
            ),
            # Native message of a different agent; filtered out below
            session.ChatMessage(
                role="native", agent_id="another-mock-agent-id", content="", native=1
            ),
            session.ChatMessage(
                role="native", agent_id="mock-agent-id", content="", native=1
            ),
        ):
            chat.async_add_message(message)

        # All five are stored, but only four are visible to "mock-agent-id"
        assert len(chat.messages) == 5
        visible = chat.async_get_messages(agent_id="mock-agent-id")
        assert len(visible) == 4
        reply_msg, native_msg = visible[2:]
        assert reply_msg == session.ChatMessage(
            role="assistant",
            agent_id="mock-agent-id",
            content="Hey!",
            native="assistant-reply-native",
        )
        assert native_msg == session.ChatMessage(
            role="native", agent_id="mock-agent-id", content="", native=1
        )
async def test_llm_api(
    hass: HomeAssistant,
    mock_conversation_input: ConversationInput,
) -> None:
    """Test that referencing the "assist" LLM API attaches an API instance."""
    requested_api = "assist"
    async with session.async_get_chat_session(
        hass, mock_conversation_input
    ) as chat:
        await chat.async_update_llm_data(
            conversing_domain="test",
            user_input=mock_conversation_input,
            user_llm_hass_api=requested_api,
            user_llm_prompt=None,
        )

        api_instance = chat.llm_api
        assert isinstance(api_instance, llm.APIInstance)
        assert api_instance.api.id == requested_api
async def test_unknown_llm_api(
    hass: HomeAssistant,
    mock_conversation_input: ConversationInput,
    snapshot: SnapshotAssertion,
) -> None:
    """Test when we reference an LLM API that does not exist."""
    async with session.async_get_chat_session(
        hass, mock_conversation_input
    ) as chat_session:
        with pytest.raises(session.ConverseError) as exc_info:
            await chat_session.async_update_llm_data(
                conversing_domain="test",
                user_input=mock_conversation_input,
                user_llm_hass_api="unknown-api",
                user_llm_prompt=None,
            )

        # Checked after the raises-block: code after the raising call inside
        # it would never execute.
        assert str(exc_info.value) == "Error getting LLM API unknown-api"
        assert exc_info.value.as_conversation_result().as_dict() == snapshot
async def test_template_error(
    hass: HomeAssistant,
    mock_conversation_input: ConversationInput,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that an unrenderable user prompt surfaces as a ``ConverseError``."""
    broken_prompt = "{{ invalid_syntax"
    async with session.async_get_chat_session(
        hass, mock_conversation_input
    ) as chat:
        with pytest.raises(session.ConverseError) as err_info:
            await chat.async_update_llm_data(
                conversing_domain="test",
                user_input=mock_conversation_input,
                user_llm_hass_api=None,
                user_llm_prompt=broken_prompt,
            )

        error = err_info.value
        assert str(error) == "Error rendering prompt"
        assert error.as_conversation_result().as_dict() == snapshot
async def test_template_variables(
    hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None:
    """Test that template variables work."""
    mock_user = Mock()
    mock_user.id = "12345"
    mock_user.name = "Test User"
    mock_conversation_input.context = Context(user_id=mock_user.id)

    async with session.async_get_chat_session(
        hass, mock_conversation_input
    ) as chat_session:
        # The user name shown to the template comes from the auth manager
        with patch(
            "homeassistant.auth.AuthManager.async_get_user", return_value=mock_user
        ):
            await chat_session.async_update_llm_data(
                conversing_domain="test",
                user_input=mock_conversation_input,
                user_llm_hass_api=None,
                # Each sentence ends with a space so the implicitly
                # concatenated pieces don't run together.
                user_llm_prompt=(
                    "The instance name is {{ ha_name }}. "
                    "The user name is {{ user_name }}. "
                    "The user id is {{ llm_context.context.user_id }}. "
                    "The calling platform is {{ llm_context.platform }}."
                ),
            )

        assert chat_session.user_name == "Test User"

        # The rendered prompt ends up in the system message
        system_prompt = chat_session.messages[0].content
        assert "The instance name is test home." in system_prompt
        assert "The user name is Test User." in system_prompt
        assert "The user id is 12345." in system_prompt
        assert "The calling platform is test." in system_prompt
async def test_extra_systen_prompt(
    hass: HomeAssistant, mock_conversation_input: ConversationInput
) -> None:
    """Test that an extra system prompt is applied and carried across turns.

    A prompt given on the input is appended to the system message, a later
    turn without one reuses the stored prompt, and a new prompt replaces it.
    """
    first_prompt = (
        "Garage door cover.garage_door has been left open for 30 minutes. "
        "We asked the user if they want to close it."
    )
    second_prompt = "User person.paulus came home. Asked him what he wants to do."

    async def refresh_llm_data(chat) -> None:
        """Render the system prompt without an LLM API or user prompt."""
        await chat.async_update_llm_data(
            conversing_domain="test",
            user_input=mock_conversation_input,
            user_llm_hass_api=None,
            user_llm_prompt=None,
        )

    def add_reply(chat) -> None:
        """Add an assistant reply to the session."""
        chat.async_add_message(
            session.ChatMessage(
                role="assistant",
                agent_id="mock-agent-id",
                content="Hey!",
            )
        )

    # Turn 1: the prompt from the input is used
    mock_conversation_input.extra_system_prompt = first_prompt
    async with session.async_get_chat_session(
        hass, mock_conversation_input
    ) as chat:
        await refresh_llm_data(chat)
        add_reply(chat)

        assert chat.extra_system_prompt == first_prompt
        assert chat.messages[0].content.endswith(first_prompt)

    # Turn 2: no prompt on the follow-up, so the previous one is kept
    mock_conversation_input.conversation_id = chat.conversation_id
    mock_conversation_input.extra_system_prompt = None
    async with session.async_get_chat_session(
        hass, mock_conversation_input
    ) as chat:
        await refresh_llm_data(chat)

        assert chat.extra_system_prompt == first_prompt
        assert chat.messages[0].content.endswith(first_prompt)

    # Turn 3: a new prompt replaces the old one
    mock_conversation_input.extra_system_prompt = second_prompt
    async with session.async_get_chat_session(
        hass, mock_conversation_input
    ) as chat:
        await refresh_llm_data(chat)
        add_reply(chat)

        assert chat.extra_system_prompt == second_prompt
        assert chat.messages[0].content.endswith(second_prompt)
        assert first_prompt not in chat.messages[0].content

    # Turn 4: no prompt again, so the replacement is kept
    mock_conversation_input.extra_system_prompt = None
    async with session.async_get_chat_session(
        hass, mock_conversation_input
    ) as chat:
        await refresh_llm_data(chat)

        assert chat.extra_system_prompt == second_prompt
        assert chat.messages[0].content.endswith(second_prompt)