# Source: core/homeassistant/components/mcp_server/http.py (171 lines, 6.6 KiB, Python)

"""Model Context Protocol transport protocol for Server Sent Events (SSE).

This registers HTTP endpoints that support SSE as a transport layer
for the Model Context Protocol. There are two HTTP endpoints:

- /mcp_server/sse: The SSE endpoint that is used to establish a session
  with the client and glue it to the MCP server. This is used to push
  responses to the client.
- /mcp_server/messages/{session_id}: The endpoint that is used by the client
  to send POST requests with new requests for the MCP server. The request
  path contains a session identifier. The response to the client is passed
  over the SSE session started on the other endpoint.

See https://modelcontextprotocol.io/docs/concepts/transports
"""
import logging
from aiohttp import web
from aiohttp.web_exceptions import HTTPBadRequest, HTTPNotFound
from aiohttp_sse import sse_response
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp import types
from homeassistant.components import conversation
from homeassistant.components.http import KEY_HASS, HomeAssistantView
from homeassistant.config_entries import ConfigEntryState
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import llm
from .const import DOMAIN
from .server import create_server
from .session import Session
from .types import MCPServerConfigEntry
_LOGGER = logging.getLogger(__name__)

# URL of the SSE view.
SSE_API = f"/{DOMAIN}/sse"

# URL template of the messages view. The "{session_id}" placeholder is kept
# literal: the template is used as the view URL and is also filled in with
# str.format() when the endpoint is announced to the client.
MESSAGES_API = f"/{DOMAIN}/messages/" + "{session_id}"
@callback
def async_register(hass: HomeAssistant) -> None:
    """Register the Model Context Protocol HTTP views.

    Instantiates the SSE view and the messages view, in that order, and
    registers each of them through ``hass.http.register_view``.
    """
    for view_cls in (ModelContextProtocolSSEView, ModelContextProtocolMessagesView):
        hass.http.register_view(view_cls())
def async_get_config_entry(hass: HomeAssistant) -> MCPServerConfigEntry:
    """Return the single loaded MCP server config entry.

    The entry's ``runtime_data`` and ``data`` are what the HTTP views use to
    reach the session manager and the configured LLM API.

    Raises:
        HTTPNotFound: If no config entry for this domain is in the LOADED
            state, or if more than one is.
    """
    loaded: list[MCPServerConfigEntry] = []
    for candidate in hass.config_entries.async_entries(DOMAIN):
        if candidate.state == ConfigEntryState.LOADED:
            loaded.append(candidate)

    # Exactly one loaded entry is the only acceptable configuration.
    if len(loaded) == 1:
        return loaded[0]
    if not loaded:
        raise HTTPNotFound(text="Model Context Protocol server is not configured")
    raise HTTPNotFound(text="Found multiple Model Context Protocol configurations")
class ModelContextProtocolSSEView(HomeAssistantView):
    """HTTP view serving the Model Context Protocol SSE endpoint."""

    name = f"{DOMAIN}:sse"
    url = SSE_API

    async def get(self, request: web.Request) -> web.StreamResponse:
        """Serve one client session over a long-lived SSE response.

        The request stays open for as long as the MCP server is running for
        this client. Two in-memory stream pairs connect the HTTP transport to
        the MCP server:

        - The "read" pair carries client messages to the server. Its writer
          end is stored in the ``Session`` registered with the session
          manager; its reader end is handed to ``server.run``.
        - The "write" pair carries server messages back. The server writes
          to it and a background task forwards each message to the client as
          an SSE ``message`` event.

        The first event sent is an ``endpoint`` event carrying the messages
        URL for the newly created session.

        Raises:
            HTTPNotFound: Propagated from ``async_get_config_entry`` when
                there is not exactly one loaded config entry.
        """
        hass = request.app[KEY_HASS]
        config_entry = async_get_config_entry(hass)
        session_manager = config_entry.runtime_data

        llm_context = llm.LLMContext(
            platform=DOMAIN,
            context=self.context(request),
            user_prompt=None,
            language="*",
            assistant=conversation.DOMAIN,
            device_id=None,
        )
        server = await create_server(
            hass, config_entry.data[CONF_LLM_HASS_API], llm_context
        )
        # Reads package for version info, so it is run in the executor.
        init_options = await hass.async_add_executor_job(
            server.create_initialization_options
        )

        # Unbuffered streams (max_buffer_size=0).
        read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
        read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
        read_stream_writer, read_stream = anyio.create_memory_object_stream(
            max_buffer_size=0
        )

        write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
        write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
        write_stream, write_stream_reader = anyio.create_memory_object_stream(
            max_buffer_size=0
        )

        async with (
            sse_response(request) as response,
            session_manager.create(Session(read_stream_writer)) as session_id,
        ):

            async def forward_responses() -> None:
                """Relay every message written by the MCP server to the client."""
                async for message in write_stream_reader:
                    _LOGGER.debug("Sending SSE message: %s", message)
                    payload = message.model_dump_json(
                        by_alias=True, exclude_none=True
                    )
                    await response.send(payload, event="message")

            # Tell the client where to POST its messages for this session.
            session_uri = MESSAGES_API.format(session_id=session_id)
            _LOGGER.debug("Sending SSE endpoint: %s", session_uri)
            await response.send(session_uri, event="endpoint")

            async with anyio.create_task_group() as task_group:
                task_group.start_soon(forward_responses)
                await server.run(read_stream, write_stream, init_options)
                return response
class ModelContextProtocolMessagesView(HomeAssistantView):
    """Model Context Protocol messages endpoint."""

    name = f"{DOMAIN}:messages"
    url = MESSAGES_API

    async def post(
        self,
        request: web.Request,
        session_id: str,
    ) -> web.StreamResponse:
        """Process incoming messages for the Model Context Protocol.

        The request passes a session ID which is used to identify the original
        SSE connection. This view parses incoming messages from the transport
        layer then writes them to the MCP server stream for the session.

        Args:
            request: The incoming POST request; its body is expected to be a
                JSON-RPC message encoded as JSON.
            session_id: Session identifier taken from the request URL.

        Returns:
            An empty 200 response once the message has been written to the
            session's stream.

        Raises:
            HTTPNotFound: If the session ID is unknown, or propagated from
                ``async_get_config_entry`` when there is not exactly one
                loaded config entry.
            HTTPBadRequest: If the body is not valid JSON or does not
                validate as a JSON-RPC message.
        """
        hass = request.app[KEY_HASS]
        config_entry = async_get_config_entry(hass)
        session_manager = config_entry.runtime_data
        if (session := session_manager.get(session_id)) is None:
            _LOGGER.info("Could not find session ID: '%s'", session_id)
            raise HTTPNotFound(text=f"Could not find session ID '{session_id}'")

        try:
            # Reading the body is guarded as well: a malformed or undecodable
            # body raises a ValueError subclass (json.JSONDecodeError or
            # UnicodeDecodeError) and must be reported as a client error
            # rather than escaping as an unhandled exception.
            json_data = await request.json()
            message = types.JSONRPCMessage.model_validate(json_data)
        except ValueError as err:
            _LOGGER.info("Failed to parse message: %s", err)
            raise HTTPBadRequest(text="Could not parse message") from err

        _LOGGER.debug("Received client message: %s", message)
        await session.read_stream_writer.send(message)
        return web.Response(status=200)