core/homeassistant/components/mcp/coordinator.py

172 lines
5.7 KiB
Python

"""Types for the Model Context Protocol integration."""
import asyncio
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
import datetime
import logging
import httpx
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
import voluptuous as vol
from voluptuous_openapi import convert_to_voluptuous
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_URL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
from homeassistant.util.json import JsonObjectType
from .const import DOMAIN
# Logger shared by every component defined in this module.
_LOGGER = logging.getLogger(__name__)
# Polling period for the coordinator's tool-list refresh (half an hour).
UPDATE_INTERVAL = datetime.timedelta(seconds=30 * 60)
# Seconds _async_setup waits for the MCP connection before timing out.
TIMEOUT = 10
@asynccontextmanager
async def mcp_client(url: str) -> AsyncGenerator[ClientSession]:
    """Open an initialized MCP client session over server-sent events.

    This async context manager bundles the SSE transport and the MCP
    ``ClientSession`` context managers into one, so callers only have a
    single object to enter and exit. The session is initialized before it
    is yielded.

    Raises:
        The first exception contained in any ``ExceptionGroup`` raised while
        the transport or session is active, chained to the original group.
    """
    try:
        async with sse_client(url=url) as streams:
            async with ClientSession(*streams) as session:
                await session.initialize()
                yield session
    except ExceptionGroup as group:
        # Unwrap the group so callers can catch the concrete error type.
        first_error = group.exceptions[0]
        raise first_error from group
class ModelContextProtocolTool(llm.Tool):
    """A Tool exposed over the Model Context Protocol.

    Wraps one tool advertised by an MCP server as an ``llm.Tool``. Calls are
    forwarded over the ``ClientSession`` supplied at construction time.
    """

    def __init__(
        self,
        name: str,
        description: str | None,
        parameters: vol.Schema,
        session: ClientSession,
    ) -> None:
        """Initialize the tool.

        Args:
            name: Tool name as reported by the MCP server.
            description: Optional description of the tool.
            parameters: Voluptuous schema for the tool's arguments.
            session: MCP client session used to invoke the tool.
        """
        self.name = name
        self.description = description
        self.parameters = parameters
        self.session = session

    async def async_call(
        self,
        hass: HomeAssistant,
        tool_input: llm.ToolInput,
        llm_context: llm.LLMContext,
    ) -> JsonObjectType:
        """Call the tool on the MCP server and return its result as a dict.

        Raises:
            HomeAssistantError: if an HTTP-level error occurs while calling
                the tool.
        """
        try:
            result = await self.session.call_tool(
                tool_input.tool_name, tool_input.tool_args
            )
        except httpx.HTTPError as error:
            # httpx.HTTPError covers bad status codes (HTTPStatusError) as well
            # as transport failures (RequestError: connect/read errors,
            # timeouts), matching how the coordinator handles list_tools errors.
            raise HomeAssistantError(f"Error when calling tool: {error}") from error
        # Omit unset and None fields from the serialized result.
        return result.model_dump(exclude_unset=True, exclude_none=True)
class ModelContextProtocolCoordinator(DataUpdateCoordinator[list[llm.Tool]]):
    """Coordinator that keeps an MCP session open and polls its tool list.

    ``_async_setup`` starts a background task that holds a server-sent event
    connection to the configured MCP server. ``_async_update_data`` then
    periodically asks that session for the tools it exposes.
    """

    config_entry: ConfigEntry
    # Active MCP session while the background connection is up, else None.
    _session: ClientSession | None = None
    # Error recorded by _connect when the connection attempt fails.
    _setup_error: Exception | None = None

    def __init__(self, hass: HomeAssistant, config_entry: ConfigEntry) -> None:
        """Initialize ModelContextProtocolCoordinator."""
        super().__init__(
            hass,
            logger=_LOGGER,
            name=DOMAIN,
            config_entry=config_entry,
            update_interval=UPDATE_INTERVAL,
        )
        # Replaced in _async_setup by the event that stops the live connection.
        self._stop = asyncio.Event()

    async def _async_setup(self) -> None:
        """Set up the client connection.

        Starts the connection task in the background. Then waits up to
        ``TIMEOUT`` seconds for the task to either connect or fail.

        Raises:
            TimeoutError: if the connection neither succeeds nor fails in time.
            Exception: the error recorded by ``_connect`` when connecting failed.
        """
        connected = asyncio.Event()
        stop = asyncio.Event()
        self._setup_error = None
        # Publish the stop event before starting the task, so close() can
        # always shut this connection down, even if setup times out.
        self._stop = stop
        self.config_entry.async_create_background_task(
            self.hass, self._connect(connected, stop), "mcp-client"
        )
        try:
            async with asyncio.timeout(TIMEOUT):
                await connected.wait()
        except TimeoutError:
            # Don't leave a late-connecting session running behind a failed setup.
            stop.set()
            if self._setup_error is None:
                raise
        # Raised outside of any finally block so that cancellation and other
        # in-flight exceptions are never swallowed.
        if self._setup_error is not None:
            raise self._setup_error

    async def _connect(self, connected: asyncio.Event, stop: asyncio.Event) -> None:
        """Hold an MCP client session open until ``stop`` is set.

        ``connected`` is set once the session is initialized. It is also set
        when the attempt fails, after the failure has been stored in
        ``_setup_error``, so ``_async_setup`` does not wait for its timeout
        to learn about an error.
        """
        url = self.config_entry.data[CONF_URL]
        try:
            async with (
                sse_client(url=url) as streams,
                ClientSession(*streams) as session,
            ):
                await session.initialize()
                self._session = session
                connected.set()
                await stop.wait()
        except httpx.HTTPStatusError as err:
            self._setup_error = err
            _LOGGER.debug("Error connecting to MCP server: %s", err)
            raise UpdateFailed(f"Error connecting to MCP server: {err}") from err
        except ExceptionGroup as err:
            # Transport/session errors may arrive wrapped in an ExceptionGroup;
            # report the first underlying error.
            first_error = err.exceptions[0]
            self._setup_error = first_error
            _LOGGER.debug("Error connecting to MCP server: %s", err)
            raise UpdateFailed(
                f"Error connecting to MCP server: {first_error}"
            ) from first_error
        finally:
            self._session = None
            # Wake _async_setup on every exit path, including failures.
            # Setting an already-set event is a no-op.
            connected.set()

    async def close(self) -> None:
        """Close the client connection by signalling the connection task to exit."""
        if self._stop is not None:
            self._stop.set()

    async def _async_update_data(self) -> list[llm.Tool]:
        """Fetch the current tool list from the MCP server.

        Each tool's JSON input schema is converted to a voluptuous schema. The
        tool is then wrapped in a ``ModelContextProtocolTool`` bound to the
        session that listed it.

        Raises:
            UpdateFailed: if no session is available, if listing tools fails
                with an HTTP error, or if a tool's input schema cannot be
                converted.
        """
        # Capture the session locally: _connect resets self._session to None
        # when the connection closes, which may happen during the await below.
        session = self._session
        if session is None:
            raise UpdateFailed("No session available")
        try:
            result = await session.list_tools()
        except httpx.HTTPError as err:
            raise UpdateFailed(f"Error communicating with API: {err}") from err
        _LOGGER.debug("Received tools: %s", result.tools)
        tools: list[llm.Tool] = []
        for tool in result.tools:
            try:
                parameters = convert_to_voluptuous(tool.inputSchema)
            except Exception as err:
                raise UpdateFailed(
                    f"Error converting schema {err}: {tool.inputSchema}"
                ) from err
            tools.append(
                ModelContextProtocolTool(
                    tool.name,
                    tool.description,
                    parameters,
                    session,
                )
            )
        return tools