core/homeassistant/components/dlna_dmr/data.py

128 lines
4.9 KiB
Python

"""Data used by this integration."""
from __future__ import annotations
import asyncio
from collections import defaultdict
from typing import NamedTuple, cast
from async_upnp_client.aiohttp import AiohttpNotifyServer, AiohttpSessionRequester
from async_upnp_client.client import UpnpRequester
from async_upnp_client.client_factory import UpnpFactory
from async_upnp_client.event_handler import UpnpEventHandler
from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant
from homeassistant.helpers import aiohttp_client
from .const import DOMAIN, LOGGER
class EventListenAddr(NamedTuple):
"""Unique identifier for an event listener."""
host: str | None # Specific local IP(v6) address for listening on
port: int # Listening port, 0 means use an ephemeral port
callback_url: str | None
class DlnaDmrData:
"""Storage class for domain global data."""
lock: asyncio.Lock
requester: UpnpRequester
upnp_factory: UpnpFactory
event_notifiers: dict[EventListenAddr, AiohttpNotifyServer]
event_notifier_refs: defaultdict[EventListenAddr, int]
stop_listener_remove: CALLBACK_TYPE | None = None
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize global data."""
self.lock = asyncio.Lock()
session = aiohttp_client.async_get_clientsession(hass, verify_ssl=False)
self.requester = AiohttpSessionRequester(session, with_sleep=True)
self.upnp_factory = UpnpFactory(self.requester, non_strict=True)
self.event_notifiers = {}
self.event_notifier_refs = defaultdict(int)
async def async_cleanup_event_notifiers(self, event: Event) -> None:
"""Clean up resources when Home Assistant is stopped."""
LOGGER.debug("Cleaning resources in DlnaDmrData")
async with self.lock:
tasks = (
server.async_stop_server() for server in self.event_notifiers.values()
)
asyncio.gather(*tasks)
self.event_notifiers = {}
self.event_notifier_refs = defaultdict(int)
async def async_get_event_notifier(
self, listen_addr: EventListenAddr, hass: HomeAssistant
) -> UpnpEventHandler:
"""Return existing event notifier for the listen_addr, or create one.
Only one event notify server is kept for each listen_addr. Must call
async_release_event_notifier when done to cleanup resources.
"""
LOGGER.debug("Getting event handler for %s", listen_addr)
async with self.lock:
# Stop all servers when HA shuts down, to release resources on devices
if not self.stop_listener_remove:
self.stop_listener_remove = hass.bus.async_listen_once(
EVENT_HOMEASSISTANT_STOP, self.async_cleanup_event_notifiers
)
# Always increment the reference counter, for existing or new event handlers
self.event_notifier_refs[listen_addr] += 1
# Return an existing event handler if we can
if listen_addr in self.event_notifiers:
return self.event_notifiers[listen_addr].event_handler
# Start event handler
source = (listen_addr.host or "0.0.0.0", listen_addr.port)
server = AiohttpNotifyServer(
requester=self.requester,
source=source,
callback_url=listen_addr.callback_url,
loop=hass.loop,
)
await server.async_start_server()
LOGGER.debug("Started event handler at %s", server.callback_url)
self.event_notifiers[listen_addr] = server
return server.event_handler
async def async_release_event_notifier(self, listen_addr: EventListenAddr) -> None:
"""Indicate that the event notifier for listen_addr is not used anymore.
This is called once by each caller of async_get_event_notifier, and will
stop the listening server when all users are done.
"""
async with self.lock:
assert self.event_notifier_refs[listen_addr] > 0
self.event_notifier_refs[listen_addr] -= 1
# Shutdown the server when it has no more users
if self.event_notifier_refs[listen_addr] == 0:
server = self.event_notifiers.pop(listen_addr)
await server.async_stop_server()
# Remove the cleanup listener when there's nothing left to cleanup
if not self.event_notifiers:
assert self.stop_listener_remove is not None
self.stop_listener_remove()
self.stop_listener_remove = None
def get_domain_data(hass: HomeAssistant) -> DlnaDmrData:
"""Obtain this integration's domain data, creating it if needed."""
if DOMAIN in hass.data:
return cast(DlnaDmrData, hass.data[DOMAIN])
data = DlnaDmrData(hass)
hass.data[DOMAIN] = data
return data