mirror of https://github.com/home-assistant/core
329 lines
10 KiB
Python
329 lines
10 KiB
Python
"""Conversation history."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import AsyncGenerator
|
|
from contextlib import asynccontextmanager
|
|
from dataclasses import dataclass, field, replace
|
|
from datetime import datetime, timedelta
|
|
import logging
|
|
from typing import Literal
|
|
|
|
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
|
|
from homeassistant.core import (
|
|
CALLBACK_TYPE,
|
|
Event,
|
|
HassJob,
|
|
HassJobType,
|
|
HomeAssistant,
|
|
callback,
|
|
)
|
|
from homeassistant.exceptions import HomeAssistantError, TemplateError
|
|
from homeassistant.helpers import intent, llm, template
|
|
from homeassistant.helpers.event import async_call_later
|
|
from homeassistant.util import dt as dt_util, ulid as ulid_util
|
|
from homeassistant.util.hass_dict import HassKey
|
|
|
|
from .const import DOMAIN
|
|
from .models import ConversationInput, ConversationResult
|
|
|
|
DATA_CHAT_HISTORY: HassKey[dict[str, ChatSession]] = HassKey(
|
|
"conversation_chat_session"
|
|
)
|
|
DATA_CHAT_HISTORY_CLEANUP: HassKey[SessionCleanup] = HassKey(
|
|
"conversation_chat_session_cleanup"
|
|
)
|
|
|
|
LOGGER = logging.getLogger(__name__)
|
|
CONVERSATION_TIMEOUT = timedelta(minutes=5)
|
|
|
|
|
|
class SessionCleanup:
|
|
"""Helper to clean up the history."""
|
|
|
|
unsub: CALLBACK_TYPE | None = None
|
|
|
|
def __init__(self, hass: HomeAssistant) -> None:
|
|
"""Initialize the history cleanup."""
|
|
self.hass = hass
|
|
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self._on_hass_stop)
|
|
self.cleanup_job = HassJob(
|
|
self._cleanup, "conversation_history_cleanup", job_type=HassJobType.Callback
|
|
)
|
|
|
|
@callback
|
|
def schedule(self) -> None:
|
|
"""Schedule the cleanup."""
|
|
if self.unsub:
|
|
return
|
|
self.unsub = async_call_later(
|
|
self.hass,
|
|
CONVERSATION_TIMEOUT.total_seconds() + 1,
|
|
self.cleanup_job,
|
|
)
|
|
|
|
@callback
|
|
def _on_hass_stop(self, event: Event) -> None:
|
|
"""Cancel the cleanup on shutdown."""
|
|
if self.unsub:
|
|
self.unsub()
|
|
self.unsub = None
|
|
|
|
@callback
|
|
def _cleanup(self, now: datetime) -> None:
|
|
"""Clean up the history and schedule follow-up if necessary."""
|
|
self.unsub = None
|
|
all_history = self.hass.data[DATA_CHAT_HISTORY]
|
|
|
|
# We mutate original object because current commands could be
|
|
# yielding history based on it.
|
|
for conversation_id, history in list(all_history.items()):
|
|
if history.last_updated + CONVERSATION_TIMEOUT < now:
|
|
del all_history[conversation_id]
|
|
|
|
# Still conversations left, check again in timeout time.
|
|
if all_history:
|
|
self.schedule()
|
|
|
|
|
|
@asynccontextmanager
|
|
async def async_get_chat_session(
|
|
hass: HomeAssistant,
|
|
user_input: ConversationInput,
|
|
) -> AsyncGenerator[ChatSession]:
|
|
"""Return chat session."""
|
|
all_history = hass.data.get(DATA_CHAT_HISTORY)
|
|
if all_history is None:
|
|
all_history = {}
|
|
hass.data[DATA_CHAT_HISTORY] = all_history
|
|
hass.data[DATA_CHAT_HISTORY_CLEANUP] = SessionCleanup(hass)
|
|
|
|
history: ChatSession | None = None
|
|
|
|
if user_input.conversation_id is None:
|
|
conversation_id = ulid_util.ulid_now()
|
|
|
|
elif history := all_history.get(user_input.conversation_id):
|
|
conversation_id = user_input.conversation_id
|
|
|
|
else:
|
|
# Conversation IDs are ULIDs. We generate a new one if not provided.
|
|
# If an old OLID is passed in, we will generate a new one to indicate
|
|
# a new conversation was started. If the user picks their own, they
|
|
# want to track a conversation and we respect it.
|
|
try:
|
|
ulid_util.ulid_to_bytes(user_input.conversation_id)
|
|
conversation_id = ulid_util.ulid_now()
|
|
except ValueError:
|
|
conversation_id = user_input.conversation_id
|
|
|
|
if history:
|
|
history = replace(history, messages=history.messages.copy())
|
|
else:
|
|
history = ChatSession(hass, conversation_id)
|
|
|
|
message: ChatMessage = ChatMessage(
|
|
role="user",
|
|
agent_id=user_input.agent_id,
|
|
content=user_input.text,
|
|
)
|
|
history.async_add_message(message)
|
|
|
|
yield history
|
|
|
|
if history.messages[-1] is message:
|
|
LOGGER.debug(
|
|
"History opened but no assistant message was added, ignoring update"
|
|
)
|
|
return
|
|
|
|
history.last_updated = dt_util.utcnow()
|
|
all_history[conversation_id] = history
|
|
hass.data[DATA_CHAT_HISTORY_CLEANUP].schedule()
|
|
|
|
|
|
class ConverseError(HomeAssistantError):
|
|
"""Error during initialization of conversation.
|
|
|
|
Will not be stored in the history.
|
|
"""
|
|
|
|
def __init__(
|
|
self, message: str, conversation_id: str, response: intent.IntentResponse
|
|
) -> None:
|
|
"""Initialize the error."""
|
|
super().__init__(message)
|
|
self.conversation_id = conversation_id
|
|
self.response = response
|
|
|
|
def as_conversation_result(self) -> ConversationResult:
|
|
"""Return the error as a conversation result."""
|
|
return ConversationResult(
|
|
response=self.response,
|
|
conversation_id=self.conversation_id,
|
|
)
|
|
|
|
|
|
@dataclass
|
|
class ChatMessage[_NativeT]:
|
|
"""Base class for chat messages.
|
|
|
|
When role is native, the content is to be ignored and message
|
|
is only meant for storing the native object.
|
|
"""
|
|
|
|
role: Literal["system", "assistant", "user", "native"]
|
|
agent_id: str | None
|
|
content: str
|
|
native: _NativeT | None = field(default=None)
|
|
|
|
# Validate in post-init that if role is native, there is no content and a native object exists
|
|
def __post_init__(self) -> None:
|
|
"""Validate native message."""
|
|
if self.role == "native" and self.native is None:
|
|
raise ValueError("Native message must have a native object")
|
|
|
|
|
|
@dataclass
|
|
class ChatSession[_NativeT]:
|
|
"""Class holding all information for a specific conversation."""
|
|
|
|
hass: HomeAssistant
|
|
conversation_id: str
|
|
user_name: str | None = None
|
|
messages: list[ChatMessage[_NativeT]] = field(
|
|
default_factory=lambda: [ChatMessage(role="system", agent_id=None, content="")]
|
|
)
|
|
extra_system_prompt: str | None = None
|
|
llm_api: llm.APIInstance | None = None
|
|
last_updated: datetime = field(default_factory=dt_util.utcnow)
|
|
|
|
@callback
|
|
def async_add_message(self, message: ChatMessage[_NativeT]) -> None:
|
|
"""Process intent."""
|
|
if message.role == "system":
|
|
raise ValueError("Cannot add system messages to history")
|
|
if message.role != "native" and self.messages[-1].role == message.role:
|
|
raise ValueError("Cannot add two assistant or user messages in a row")
|
|
|
|
self.messages.append(message)
|
|
|
|
@callback
|
|
def async_get_messages(self, agent_id: str | None) -> list[ChatMessage[_NativeT]]:
|
|
"""Get messages for a specific agent ID.
|
|
|
|
This will filter out any native message tied to other agent IDs.
|
|
It can still include assistant/user messages generated by other agents.
|
|
"""
|
|
return [
|
|
message
|
|
for message in self.messages
|
|
if message.role != "native" or message.agent_id == agent_id
|
|
]
|
|
|
|
async def async_update_llm_data(
|
|
self,
|
|
conversing_domain: str,
|
|
user_input: ConversationInput,
|
|
user_llm_hass_api: str | None = None,
|
|
user_llm_prompt: str | None = None,
|
|
) -> None:
|
|
"""Set the LLM system prompt."""
|
|
llm_context = llm.LLMContext(
|
|
platform=conversing_domain,
|
|
context=user_input.context,
|
|
user_prompt=user_input.text,
|
|
language=user_input.language,
|
|
assistant=DOMAIN,
|
|
device_id=user_input.device_id,
|
|
)
|
|
|
|
llm_api: llm.APIInstance | None = None
|
|
|
|
if user_llm_hass_api:
|
|
try:
|
|
llm_api = await llm.async_get_api(
|
|
self.hass,
|
|
user_llm_hass_api,
|
|
llm_context,
|
|
)
|
|
except HomeAssistantError as err:
|
|
LOGGER.error(
|
|
"Error getting LLM API %s for %s: %s",
|
|
user_llm_hass_api,
|
|
conversing_domain,
|
|
err,
|
|
)
|
|
intent_response = intent.IntentResponse(language=user_input.language)
|
|
intent_response.async_set_error(
|
|
intent.IntentResponseErrorCode.UNKNOWN,
|
|
"Error preparing LLM API",
|
|
)
|
|
raise ConverseError(
|
|
f"Error getting LLM API {user_llm_hass_api}",
|
|
conversation_id=self.conversation_id,
|
|
response=intent_response,
|
|
) from err
|
|
|
|
user_name: str | None = None
|
|
|
|
if (
|
|
user_input.context
|
|
and user_input.context.user_id
|
|
and (
|
|
user := await self.hass.auth.async_get_user(user_input.context.user_id)
|
|
)
|
|
):
|
|
user_name = user.name
|
|
|
|
try:
|
|
prompt_parts = [
|
|
template.Template(
|
|
llm.BASE_PROMPT
|
|
+ (user_llm_prompt or llm.DEFAULT_INSTRUCTIONS_PROMPT),
|
|
self.hass,
|
|
).async_render(
|
|
{
|
|
"ha_name": self.hass.config.location_name,
|
|
"user_name": user_name,
|
|
"llm_context": llm_context,
|
|
},
|
|
parse_result=False,
|
|
)
|
|
]
|
|
|
|
except TemplateError as err:
|
|
LOGGER.error("Error rendering prompt: %s", err)
|
|
intent_response = intent.IntentResponse(language=user_input.language)
|
|
intent_response.async_set_error(
|
|
intent.IntentResponseErrorCode.UNKNOWN,
|
|
"Sorry, I had a problem with my template",
|
|
)
|
|
raise ConverseError(
|
|
"Error rendering prompt",
|
|
conversation_id=self.conversation_id,
|
|
response=intent_response,
|
|
) from err
|
|
|
|
if llm_api:
|
|
prompt_parts.append(llm_api.api_prompt)
|
|
|
|
extra_system_prompt = (
|
|
# Take new system prompt if one was given
|
|
user_input.extra_system_prompt or self.extra_system_prompt
|
|
)
|
|
|
|
if extra_system_prompt:
|
|
prompt_parts.append(extra_system_prompt)
|
|
|
|
prompt = "\n".join(prompt_parts)
|
|
|
|
self.llm_api = llm_api
|
|
self.user_name = user_name
|
|
self.extra_system_prompt = extra_system_prompt
|
|
self.messages[0] = ChatMessage(
|
|
role="system",
|
|
agent_id=user_input.agent_id,
|
|
content=prompt,
|
|
)
|