core/homeassistant/components/conversation/chat_log.py

"""Conversation chat log."""
from __future__ import annotations
import asyncio
from collections.abc import AsyncGenerator, AsyncIterable, Callable, Generator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import asdict, dataclass, field, replace
import logging
from typing import Literal, TypedDict
import voluptuous as vol
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, TemplateError
from homeassistant.helpers import chat_session, intent, llm, template
from homeassistant.util.hass_dict import HassKey
from homeassistant.util.json import JsonObjectType
from . import trace
from .const import DOMAIN
from .models import ConversationInput, ConversationResult
DATA_CHAT_LOGS: HassKey[dict[str, ChatLog]] = HassKey("conversation_chat_logs")
LOGGER = logging.getLogger(__name__)
current_chat_log: ContextVar[ChatLog | None] = ContextVar(
"current_chat_log", default=None
)


@contextmanager
def async_get_chat_log(
    hass: HomeAssistant,
    session: chat_session.ChatSession,
    user_input: ConversationInput | None = None,
    *,
    chat_log_delta_listener: Callable[[ChatLog, dict], None] | None = None,
) -> Generator[ChatLog]:
    """Return chat log for a specific chat session."""
    # If a chat log is already active and it's the requested conversation ID,
    # return that. We won't update the last updated time in this case.
    if (
        chat_log := current_chat_log.get()
    ) and chat_log.conversation_id == session.conversation_id:
        if chat_log_delta_listener is not None:
            raise RuntimeError(
                "Cannot attach chat log delta listener unless initial caller"
            )
        if user_input is not None:
            chat_log.async_add_user_content(UserContent(content=user_input.text))

        yield chat_log
        return

    all_chat_logs = hass.data.get(DATA_CHAT_LOGS)
    if all_chat_logs is None:
        all_chat_logs = {}
        hass.data[DATA_CHAT_LOGS] = all_chat_logs

    if chat_log := all_chat_logs.get(session.conversation_id):
        chat_log = replace(chat_log, content=chat_log.content.copy())
    else:
        chat_log = ChatLog(hass, session.conversation_id)

    if chat_log_delta_listener:
        chat_log.delta_listener = chat_log_delta_listener

    if user_input is not None:
        chat_log.async_add_user_content(UserContent(content=user_input.text))

    last_message = chat_log.content[-1]

    token = current_chat_log.set(chat_log)
    yield chat_log
    current_chat_log.reset(token)

    if chat_log.content[-1] is last_message:
        LOGGER.debug(
            "Chat Log opened but no assistant message was added, ignoring update"
        )
        return

    if session.conversation_id not in all_chat_logs:

        @callback
        def do_cleanup() -> None:
            """Handle cleanup."""
            all_chat_logs.pop(session.conversation_id)

        session.async_on_cleanup(do_cleanup)

    if chat_log_delta_listener:
        chat_log.delta_listener = None

    all_chat_logs[session.conversation_id] = chat_log
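

# Illustrative sketch (editor's example, not part of the original module): a
# caller pairs this context manager with a chat session from the chat_session
# helper. The surrounding agent code is hypothetical; only async_get_chat_log
# and chat_session.async_get_chat_session are real.
def _example_open_chat_log(
    hass: HomeAssistant, user_input: ConversationInput
) -> None:
    """Show the intended nesting of session and chat log context managers."""
    with (
        chat_session.async_get_chat_session(
            hass, user_input.conversation_id
        ) as session,
        async_get_chat_log(hass, session, user_input) as chat_log,
    ):
        # user_input.text has already been appended as UserContent; an agent
        # would now generate a reply and add it as AssistantContent.
        LOGGER.debug("Chat log has %d messages", len(chat_log.content))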


class ConverseError(HomeAssistantError):
    """Error during initialization of conversation.

    Will not be stored in the history.
    """

    def __init__(
        self, message: str, conversation_id: str, response: intent.IntentResponse
    ) -> None:
        """Initialize the error."""
        super().__init__(message)
        self.conversation_id = conversation_id
        self.response = response

    def as_conversation_result(self) -> ConversationResult:
        """Return the error as a conversation result."""
        return ConversationResult(
            response=self.response,
            conversation_id=self.conversation_id,
        )
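

# Illustrative sketch (editor's example): ConverseError is meant to be caught
# at the conversation boundary and converted back into a result so the failure
# reaches the user without being stored in the chat history.
def _example_error_to_result(err: ConverseError) -> ConversationResult:
    """Hypothetical boundary handler surfacing a ConverseError to the user."""
    # An agent would wrap its converse step in try/except ConverseError and
    # return this result instead of re-raising.
    return err.as_conversation_result()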


@dataclass(frozen=True)
class SystemContent:
    """System content."""

    role: str = field(init=False, default="system")
    content: str


@dataclass(frozen=True)
class UserContent:
    """User content."""

    role: str = field(init=False, default="user")
    content: str


@dataclass(frozen=True)
class AssistantContent:
    """Assistant content."""

    role: str = field(init=False, default="assistant")
    agent_id: str
    content: str | None = None
    tool_calls: list[llm.ToolInput] | None = None


@dataclass(frozen=True)
class ToolResultContent:
    """Tool result content."""

    role: str = field(init=False, default="tool_result")
    agent_id: str
    tool_call_id: str
    tool_name: str
    tool_result: JsonObjectType


type Content = SystemContent | UserContent | AssistantContent | ToolResultContent
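

# Illustrative sketch (editor's example): a completed tool-calling exchange is
# stored as a flat list of these frozen content types. The agent and call
# identifiers below are invented for illustration.
def _example_transcript() -> list[Content]:
    """Return a hypothetical chat log content list."""
    return [
        SystemContent(content="You are a helpful assistant."),
        UserContent(content="Turn on the kitchen light."),
        AssistantContent(
            agent_id="example-agent",
            tool_calls=[
                llm.ToolInput(
                    id="call-1",
                    tool_name="HassTurnOn",
                    tool_args={"name": "kitchen light"},
                )
            ],
        ),
        ToolResultContent(
            agent_id="example-agent",
            tool_call_id="call-1",
            tool_name="HassTurnOn",
            tool_result={"success": True},
        ),
        AssistantContent(agent_id="example-agent", content="Done!"),
    ]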


class AssistantContentDeltaDict(TypedDict, total=False):
    """Partial content to define an AssistantContent."""

    role: Literal["assistant"]
    content: str | None
    tool_calls: list[llm.ToolInput] | None
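

# Illustrative sketch (editor's example): how a streamed reply decomposes into
# deltas. A new message starts with a role key; later deltas carry only the
# content/tool_calls fragments, which are concatenated.
_EXAMPLE_DELTAS: list[AssistantContentDeltaDict] = [
    {"role": "assistant"},
    {"content": "Turning on "},
    {"content": "the kitchen light."},
]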


@dataclass
class ChatLog:
    """Class holding the chat history of a specific conversation."""

    hass: HomeAssistant
    conversation_id: str
    content: list[Content] = field(default_factory=lambda: [SystemContent(content="")])
    extra_system_prompt: str | None = None
    llm_api: llm.APIInstance | None = None
    delta_listener: Callable[[ChatLog, dict], None] | None = None

    @property
    def unresponded_tool_results(self) -> bool:
        """Return if there are unresponded tool results."""
        return self.content[-1].role == "tool_result"

    @callback
    def async_add_user_content(self, content: UserContent) -> None:
        """Add user content to the log."""
        LOGGER.debug("Adding user content: %s", content)
        self.content.append(content)

    @callback
    def async_add_assistant_content_without_tools(
        self, content: AssistantContent
    ) -> None:
        """Add assistant content to the log."""
        LOGGER.debug("Adding assistant content: %s", content)
        if content.tool_calls is not None:
            raise ValueError("Tool calls not allowed")
        self.content.append(content)

    async def async_add_assistant_content(
        self,
        content: AssistantContent,
        /,
        tool_call_tasks: dict[str, asyncio.Task] | None = None,
    ) -> AsyncGenerator[ToolResultContent]:
        """Add assistant content and execute tool calls.

        tool_call_tasks can contain tasks for tool calls that are already
        in progress.

        This method is an async generator and will yield the tool results
        as they come in.
        """
        LOGGER.debug("Adding assistant content: %s", content)
        self.content.append(content)

        if content.tool_calls is None:
            return

        if self.llm_api is None:
            raise ValueError("No LLM API configured")

        if tool_call_tasks is None:
            tool_call_tasks = {}
        for tool_input in content.tool_calls:
            if tool_input.id not in tool_call_tasks:
                tool_call_tasks[tool_input.id] = self.hass.async_create_task(
                    self.llm_api.async_call_tool(tool_input),
                    name=f"llm_tool_{tool_input.id}",
                )

        for tool_input in content.tool_calls:
            LOGGER.debug(
                "Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args
            )

            try:
                tool_result = await tool_call_tasks[tool_input.id]
            except (HomeAssistantError, vol.Invalid) as e:
                tool_result = {"error": type(e).__name__}
                if str(e):
                    tool_result["error_text"] = str(e)
            LOGGER.debug("Tool response: %s", tool_result)

            response_content = ToolResultContent(
                agent_id=content.agent_id,
                tool_call_id=tool_input.id,
                tool_name=tool_input.tool_name,
                tool_result=tool_result,
            )
            self.content.append(response_content)
            yield response_content
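
    # Illustrative sketch (editor's example): because the method above is an
    # async generator, its side effects only happen while it is iterated, so
    # callers must drain it even if they ignore the yielded tool results.
    async def _example_drain_tool_results(self, content: AssistantContent) -> None:
        """Hypothetical helper showing how a caller consumes the generator."""
        async for tool_result in self.async_add_assistant_content(content):
            # Each ToolResultContent has already been appended to self.content.
            LOGGER.debug("Example tool result: %s", tool_result.tool_call_id)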

    async def async_add_delta_content_stream(
        self, agent_id: str, stream: AsyncIterable[AssistantContentDeltaDict]
    ) -> AsyncGenerator[AssistantContent | ToolResultContent]:
        """Stream content into the chat log.

        Yields all content that is added to the chat log as it arrives.

        The stream iterates over dictionaries with the optional keys role,
        content and tool_calls.

        When a delta contains a role key, the current message is considered
        complete and a new message is started.

        The keys content and tool_calls are concatenated if they appear
        multiple times.
        """
        current_content = ""
        current_tool_calls: list[llm.ToolInput] = []
        tool_call_tasks: dict[str, asyncio.Task] = {}

        async for delta in stream:
            LOGGER.debug("Received delta: %s", delta)

            # Indicates update to current message
            if "role" not in delta:
                if delta_content := delta.get("content"):
                    current_content += delta_content
                if delta_tool_calls := delta.get("tool_calls"):
                    if self.llm_api is None:
                        raise ValueError("No LLM API configured")
                    current_tool_calls += delta_tool_calls

                    # Start processing the tool calls as soon as we know about them
                    for tool_call in delta_tool_calls:
                        tool_call_tasks[tool_call.id] = self.hass.async_create_task(
                            self.llm_api.async_call_tool(tool_call),
                            name=f"llm_tool_{tool_call.id}",
                        )

                if self.delta_listener:
                    self.delta_listener(self, delta)  # type: ignore[arg-type]

                continue

            # Starting a new message
            if delta["role"] != "assistant":
                raise ValueError(f"Only assistant role expected. Got {delta['role']}")

            # Yield the previous message if it has content
            if current_content or current_tool_calls:
                content = AssistantContent(
                    agent_id=agent_id,
                    content=current_content or None,
                    tool_calls=current_tool_calls or None,
                )
                yield content
                async for tool_result in self.async_add_assistant_content(
                    content, tool_call_tasks=tool_call_tasks
                ):
                    yield tool_result
                    if self.delta_listener:
                        self.delta_listener(self, asdict(tool_result))

            current_content = delta.get("content") or ""
            current_tool_calls = delta.get("tool_calls") or []

            if self.delta_listener:
                self.delta_listener(self, delta)  # type: ignore[arg-type]

        if current_content or current_tool_calls:
            content = AssistantContent(
                agent_id=agent_id,
                content=current_content or None,
                tool_calls=current_tool_calls or None,
            )
            yield content
            async for tool_result in self.async_add_assistant_content(
                content, tool_call_tasks=tool_call_tasks
            ):
                yield tool_result
                if self.delta_listener:
                    self.delta_listener(self, asdict(tool_result))
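
    # Illustrative sketch (editor's example): an agent adapts its model SDK's
    # stream into AssistantContentDeltaDict items and drains the generator
    # above. The plain-text chunk stream used here is invented for
    # illustration; real deltas may also carry tool_calls.
    async def _example_stream_text_reply(
        self, agent_id: str, text_chunks: AsyncIterable[str]
    ) -> None:
        """Hypothetical adapter feeding raw text chunks through the stream API."""

        async def _deltas() -> AsyncGenerator[AssistantContentDeltaDict]:
            yield {"role": "assistant"}  # start a new assistant message
            async for chunk in text_chunks:
                yield {"content": chunk}  # fragments are concatenated

        async for _content in self.async_add_delta_content_stream(
            agent_id, _deltas()
        ):
            pass  # each yielded item is already appended to self.content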

    async def async_update_llm_data(
        self,
        conversing_domain: str,
        user_input: ConversationInput,
        user_llm_hass_api: str | None = None,
        user_llm_prompt: str | None = None,
    ) -> None:
        """Set the LLM system prompt."""
        llm_context = llm.LLMContext(
            platform=conversing_domain,
            context=user_input.context,
            user_prompt=user_input.text,
            language=user_input.language,
            assistant=DOMAIN,
            device_id=user_input.device_id,
        )

        llm_api: llm.APIInstance | None = None

        if user_llm_hass_api:
            try:
                llm_api = await llm.async_get_api(
                    self.hass,
                    user_llm_hass_api,
                    llm_context,
                )
            except HomeAssistantError as err:
                LOGGER.error(
                    "Error getting LLM API %s for %s: %s",
                    user_llm_hass_api,
                    conversing_domain,
                    err,
                )
                intent_response = intent.IntentResponse(language=user_input.language)
                intent_response.async_set_error(
                    intent.IntentResponseErrorCode.UNKNOWN,
                    "Error preparing LLM API",
                )
                raise ConverseError(
                    f"Error getting LLM API {user_llm_hass_api}",
                    conversation_id=self.conversation_id,
                    response=intent_response,
                ) from err

        user_name: str | None = None

        if (
            user_input.context
            and user_input.context.user_id
            and (
                user := await self.hass.auth.async_get_user(user_input.context.user_id)
            )
        ):
            user_name = user.name

        try:
            prompt_parts = [
                template.Template(
                    llm.BASE_PROMPT
                    + (user_llm_prompt or llm.DEFAULT_INSTRUCTIONS_PROMPT),
                    self.hass,
                ).async_render(
                    {
                        "ha_name": self.hass.config.location_name,
                        "user_name": user_name,
                        "llm_context": llm_context,
                    },
                    parse_result=False,
                )
            ]
        except TemplateError as err:
            LOGGER.error("Error rendering prompt: %s", err)
            intent_response = intent.IntentResponse(language=user_input.language)
            intent_response.async_set_error(
                intent.IntentResponseErrorCode.UNKNOWN,
                "Sorry, I had a problem with my template",
            )
            raise ConverseError(
                "Error rendering prompt",
                conversation_id=self.conversation_id,
                response=intent_response,
            ) from err

        if llm_api:
            prompt_parts.append(llm_api.api_prompt)

        if extra_system_prompt := (
            # Take new system prompt if one was given
            user_input.extra_system_prompt or self.extra_system_prompt
        ):
            prompt_parts.append(extra_system_prompt)

        prompt = "\n".join(prompt_parts)

        self.llm_api = llm_api
        self.extra_system_prompt = extra_system_prompt
        self.content[0] = SystemContent(content=prompt)

        LOGGER.debug("Prompt: %s", self.content)
        LOGGER.debug("Tools: %s", self.llm_api.tools if self.llm_api else None)

        trace.async_conversation_trace_append(
            trace.ConversationTraceEventType.AGENT_DETAIL,
            {
                "messages": self.content,
                "tools": self.llm_api.tools if self.llm_api else None,
            },
        )
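

# Illustrative sketch (editor's example): an LLM-backed agent refreshes the
# system prompt before each turn. llm.LLM_API_ASSIST is the built-in Assist
# API; the domain name used here is invented.
async def _example_prepare_turn(
    chat_log: ChatLog, user_input: ConversationInput
) -> None:
    """Hypothetical per-turn setup for an LLM-backed conversation agent."""
    await chat_log.async_update_llm_data(
        conversing_domain="example_agent",
        user_input=user_input,
        user_llm_hass_api=llm.LLM_API_ASSIST,
        user_llm_prompt=None,
    )
    # The rendered system prompt now lives in chat_log.content[0].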