mirror of https://github.com/home-assistant/core
171 lines
6.6 KiB
Python
171 lines
6.6 KiB
Python
"""Model Context Protocol transport protocol for Server Sent Events (SSE).
|
|
|
|
This registers HTTP endpoints that support SSE as a transport layer
|
|
for the Model Context Protocol. There are two HTTP endpoints:
|
|
|
|
- /mcp_server/sse: The SSE endpoint that is used to establish a session
|
|
with the client and glue to the MCP server. This is used to push responses
|
|
to the client.
|
|
- /mcp_server/messages: The endpoint that is used by the client to send
|
|
POST requests carrying new messages for the MCP server. The request contains
|
|
a session identifier. The response to the client is passed over the SSE
|
|
session started on the other endpoint.
|
|
|
|
See https://modelcontextprotocol.io/docs/concepts/transports
|
|
"""
|
|
|
|
import logging
|
|
|
|
from aiohttp import web
|
|
from aiohttp.web_exceptions import HTTPBadRequest, HTTPNotFound
|
|
from aiohttp_sse import sse_response
|
|
import anyio
|
|
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
from mcp import types
|
|
|
|
from homeassistant.components import conversation
|
|
from homeassistant.components.http import KEY_HASS, HomeAssistantView
|
|
from homeassistant.config_entries import ConfigEntryState
|
|
from homeassistant.const import CONF_LLM_HASS_API
|
|
from homeassistant.core import HomeAssistant, callback
|
|
from homeassistant.helpers import llm
|
|
|
|
from .const import DOMAIN
|
|
from .server import create_server
|
|
from .session import Session
|
|
from .types import MCPServerConfigEntry
|
|
|
|
_LOGGER = logging.getLogger(__name__)

# Endpoint a client connects to in order to open the long-lived SSE session.
SSE_API = f"/{DOMAIN}/sse"

# Per-session endpoint the client POSTs its messages to. The ``{session_id}``
# placeholder is a route parameter for the messages view and is also filled
# in with ``str.format`` when the SSE view announces the endpoint.
MESSAGES_API = "/" + DOMAIN + "/messages/{session_id}"
|
|
|
|
|
|
@callback
def async_register(hass: HomeAssistant) -> None:
    """Register the Model Context Protocol HTTP views.

    Adds the SSE view (served at ``SSE_API``) and the messages view (served
    at ``MESSAGES_API``) to the Home Assistant HTTP server. Together they form
    the SSE transport for MCP clients.
    """
    hass.http.register_view(ModelContextProtocolSSEView())
    hass.http.register_view(ModelContextProtocolMessagesView())
|
|
|
|
|
|
def async_get_config_entry(hass: HomeAssistant) -> MCPServerConfigEntry:
|
|
"""Get the first enabled MCP server config entry.
|
|
|
|
The ConfigEntry contains a reference to the actual MCP server used to
|
|
serve the Model Context Protocol.
|
|
|
|
Will raise an HTTP error if the expected configuration is not present.
|
|
"""
|
|
config_entries: list[MCPServerConfigEntry] = [
|
|
config_entry
|
|
for config_entry in hass.config_entries.async_entries(DOMAIN)
|
|
if config_entry.state == ConfigEntryState.LOADED
|
|
]
|
|
if not config_entries:
|
|
raise HTTPNotFound(text="Model Context Protocol server is not configured")
|
|
if len(config_entries) > 1:
|
|
raise HTTPNotFound(text="Found multiple Model Context Protocol configurations")
|
|
return config_entries[0]
|
|
|
|
|
|
class ModelContextProtocolSSEView(HomeAssistantView):
    """HTTP view serving the Model Context Protocol SSE stream."""

    name = f"{DOMAIN}:sse"
    url = SSE_API

    async def get(self, request: web.Request) -> web.StreamResponse:
        """Run one MCP client session over a Server Sent Events stream.

        The request stays open for the whole client session and is the
        channel the server uses to push messages to the client.

        Two in-memory streams, both created with a buffer size of 0, connect
        the HTTP transport to the MCP SDK server:

        * inbound: messages POSTed to the messages view are written through
          the registered ``Session`` and read by the MCP server;
        * outbound: messages produced by the MCP server are read here and
          forwarded to the client as ``message`` SSE events.

        The MCP SDK handles the protocol details and invokes our server.
        """
        hass = request.app[KEY_HASS]
        config_entry = async_get_config_entry(hass)
        sessions = config_entry.runtime_data

        llm_context = llm.LLMContext(
            platform=DOMAIN,
            context=self.context(request),
            user_prompt=None,
            language="*",
            assistant=conversation.DOMAIN,
            device_id=None,
        )
        mcp_server = await create_server(
            hass, config_entry.data[CONF_LLM_HASS_API], llm_context
        )
        # Building the options reads the package for version info, so it is
        # run in the executor rather than on the event loop.
        init_options = await hass.async_add_executor_job(
            mcp_server.create_initialization_options
        )

        # Inbound: client -> MCP server.
        inbound_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
        inbound_reader: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
        inbound_writer, inbound_reader = anyio.create_memory_object_stream(0)

        # Outbound: MCP server -> client.
        outbound_writer: MemoryObjectSendStream[types.JSONRPCMessage]
        outbound_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
        outbound_writer, outbound_reader = anyio.create_memory_object_stream(0)

        async with (
            sse_response(request) as response,
            sessions.create(Session(inbound_writer)) as session_id,
        ):
            # Tell the client which URL to POST its messages to.
            endpoint = MESSAGES_API.format(session_id=session_id)
            _LOGGER.debug("Sending SSE endpoint: %s", endpoint)
            await response.send(endpoint, event="endpoint")

            async def forward_to_client() -> None:
                """Relay each MCP server message to the client as an SSE event."""
                async for outgoing in outbound_reader:
                    _LOGGER.debug("Sending SSE message: %s", outgoing)
                    payload = outgoing.model_dump_json(
                        by_alias=True, exclude_none=True
                    )
                    await response.send(payload, event="message")

            async with anyio.create_task_group() as task_group:
                task_group.start_soon(forward_to_client)
                await mcp_server.run(inbound_reader, outbound_writer, init_options)

        return response
|
|
|
|
|
|
class ModelContextProtocolMessagesView(HomeAssistantView):
    """Model Context Protocol messages endpoint."""

    name = f"{DOMAIN}:messages"
    url = MESSAGES_API

    async def post(
        self,
        request: web.Request,
        session_id: str,
    ) -> web.StreamResponse:
        """Process incoming messages for the Model Context Protocol.

        The request passes a session ID which is used to identify the original
        SSE connection. This view parses the incoming message from the
        transport layer, then writes it to the MCP server stream for that
        session. The reply to the client is delivered over the SSE stream,
        not in this response.

        Raises:
            HTTPNotFound: if the session ID is unknown, or if the session's
                SSE side has already shut down its stream.
            HTTPBadRequest: if the body is not valid JSON or is not a valid
                JSON-RPC message.
        """
        hass = request.app[KEY_HASS]
        config_entry = async_get_config_entry(hass)

        session_manager = config_entry.runtime_data
        if (session := session_manager.get(session_id)) is None:
            _LOGGER.info("Could not find session ID: '%s'", session_id)
            raise HTTPNotFound(text=f"Could not find session ID '{session_id}'")

        try:
            # Body decoding belongs inside the try. A malformed or undecodable
            # body raises json.JSONDecodeError / UnicodeDecodeError (both
            # ValueError subclasses) and must become a 400, not a 500.
            json_data = await request.json()
            message = types.JSONRPCMessage.model_validate(json_data)
        except ValueError as err:
            _LOGGER.info("Failed to parse message: %s", err)
            raise HTTPBadRequest(text="Could not parse message") from err

        _LOGGER.debug("Received client message: %s", message)
        try:
            await session.read_stream_writer.send(message)
        except (anyio.BrokenResourceError, anyio.ClosedResourceError) as err:
            # The SSE side closed the stream (for example, the client
            # disconnected) after the session was looked up. Report the
            # session as gone instead of failing with a 500.
            _LOGGER.info("Session '%s' is no longer accepting messages", session_id)
            raise HTTPNotFound(
                text=f"Could not find session ID '{session_id}'"
            ) from err
        return web.Response(status=200)
|