mirror of https://github.com/home-assistant/core
540 lines
16 KiB
Python
540 lines
16 KiB
Python
"""Advanced timeout handling.

Set of helper classes to handle timeouts of tasks with advanced options
like zones and freezing of timeouts.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import enum
|
|
from types import TracebackType
|
|
from typing import Any, Self
|
|
|
|
from .async_ import run_callback_threadsafe
|
|
|
|
# Default zone name: when a caller passes this (or nothing) as ``zone_name``,
# TimeoutManager hands out the global task/freeze contexts instead of
# per-zone ones.
ZONE_GLOBAL = "global"
|
|
|
|
|
|
class _State(enum.Enum):
    """States of a task tracked by a timeout context."""

    # Set when a task context is constructed, before it is entered.
    INIT = "INIT"
    # Set by ``__aenter__`` of the task contexts.
    ACTIVE = "ACTIVE"
    # Set by the timer callback when the timeout fired.
    TIMEOUT = "TIMEOUT"
    # Set by ``__aexit__`` when the context is left without a timeout.
    EXIT = "EXIT"
|
|
|
|
|
|
class _GlobalFreezeContext:
|
|
"""Context manager that freezes the global timeout."""
|
|
|
|
def __init__(self, manager: TimeoutManager) -> None:
|
|
"""Initialize internal timeout context manager."""
|
|
self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
|
|
self._manager: TimeoutManager = manager
|
|
|
|
async def __aenter__(self) -> Self:
|
|
self._enter()
|
|
return self
|
|
|
|
async def __aexit__(
|
|
self,
|
|
exc_type: type[BaseException] | None,
|
|
exc_val: BaseException | None,
|
|
exc_tb: TracebackType | None,
|
|
) -> bool | None:
|
|
self._exit()
|
|
return None
|
|
|
|
def __enter__(self) -> Self:
|
|
self._loop.call_soon_threadsafe(self._enter)
|
|
return self
|
|
|
|
def __exit__(
|
|
self,
|
|
exc_type: type[BaseException] | None,
|
|
exc_val: BaseException | None,
|
|
exc_tb: TracebackType | None,
|
|
) -> bool | None:
|
|
self._loop.call_soon_threadsafe(self._exit)
|
|
return None
|
|
|
|
def _enter(self) -> None:
|
|
"""Run freeze."""
|
|
if self._manager.freezes_done:
|
|
# Global reset
|
|
for task in self._manager.global_tasks:
|
|
task.pause()
|
|
|
|
# Zones reset
|
|
for zone in self._manager.zones.values():
|
|
if not zone.freezes_done:
|
|
continue
|
|
zone.pause()
|
|
|
|
self._manager.global_freezes.append(self)
|
|
|
|
def _exit(self) -> None:
|
|
"""Finish freeze."""
|
|
self._manager.global_freezes.remove(self)
|
|
if not self._manager.freezes_done:
|
|
return
|
|
|
|
# Global reset
|
|
for task in self._manager.global_tasks:
|
|
task.reset()
|
|
|
|
# Zones reset
|
|
for zone in self._manager.zones.values():
|
|
if not zone.freezes_done:
|
|
continue
|
|
zone.reset()
|
|
|
|
|
|
class _ZoneFreezeContext:
|
|
"""Context manager that freezes a zone timeout."""
|
|
|
|
def __init__(self, zone: _ZoneTimeoutManager) -> None:
|
|
"""Initialize internal timeout context manager."""
|
|
self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
|
|
self._zone: _ZoneTimeoutManager = zone
|
|
|
|
async def __aenter__(self) -> Self:
|
|
self._enter()
|
|
return self
|
|
|
|
async def __aexit__(
|
|
self,
|
|
exc_type: type[BaseException] | None,
|
|
exc_val: BaseException | None,
|
|
exc_tb: TracebackType | None,
|
|
) -> bool | None:
|
|
self._exit()
|
|
return None
|
|
|
|
def __enter__(self) -> Self:
|
|
self._loop.call_soon_threadsafe(self._enter)
|
|
return self
|
|
|
|
def __exit__(
|
|
self,
|
|
exc_type: type[BaseException] | None,
|
|
exc_val: BaseException | None,
|
|
exc_tb: TracebackType | None,
|
|
) -> bool | None:
|
|
self._loop.call_soon_threadsafe(self._exit)
|
|
return None
|
|
|
|
def _enter(self) -> None:
|
|
"""Run freeze."""
|
|
if self._zone.freezes_done:
|
|
self._zone.pause()
|
|
self._zone.enter_freeze(self)
|
|
|
|
def _exit(self) -> None:
|
|
"""Finish freeze."""
|
|
self._zone.exit_freeze(self)
|
|
if not self._zone.freezes_done:
|
|
return
|
|
self._zone.reset()
|
|
|
|
|
|
class _GlobalTaskContext:
    """Context manager that tracks a global task.

    Entering registers the context with the manager and arms a timer.  When
    the timer fires the tracked task is cancelled: immediately if the manager
    reports that all zones are done, otherwise only after the zones-done
    signal arrived and ``cool_down`` seconds have passed.  On exit, a
    cancellation that was caused by the timeout is turned into
    ``TimeoutError``.
    """

    def __init__(
        self,
        manager: TimeoutManager,
        task: asyncio.Task[Any],
        timeout: float,
        cool_down: float,
    ) -> None:
        """Initialize internal timeout context manager.

        Args:
            manager: Manager this context registers with.
            task: Task that gets cancelled when the timeout is hit.
            timeout: Time budget, added to ``loop.time()`` to schedule the
                timer.
            cool_down: Delay passed to ``asyncio.sleep`` after the zones-done
                signal, before the task is cancelled.
        """
        self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
        self._manager: TimeoutManager = manager
        self._task: asyncio.Task[Any] = task
        self._time_left: float = timeout
        self._expiration_time: float | None = None
        self._timeout_handler: asyncio.Handle | None = None
        self._on_wait_task: asyncio.Task[None] | None = None
        self._wait_zone: asyncio.Event = asyncio.Event()
        self._state: _State = _State.INIT
        self._cool_down: float = cool_down
        self._cancelling: int = 0

    async def __aenter__(self) -> Self:
        """Register with the manager and start the timer."""
        self._manager.global_tasks.append(self)
        self._start_timer()
        self._state = _State.ACTIVE
        # Remember if the task was already cancelling
        # so when we __aexit__ we can decide if we should
        # raise asyncio.TimeoutError or let the cancellation propagate
        self._cancelling = self._task.cancelling()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """Unregister and translate a timeout cancellation.

        Raises:
            TimeoutError: If the block was left through a ``CancelledError``
                after the timeout fired and no other cancellation request is
                outstanding on the task.
        """
        self._stop_timer()
        self._manager.global_tasks.remove(self)

        # Timeout on exit
        if exc_type is asyncio.CancelledError and self.state is _State.TIMEOUT:
            # The timeout was hit, and the task was cancelled
            # so we need to uncancel the task since the cancellation
            # should not leak out of the context manager
            if self._task.uncancel() > self._cancelling:
                # If the task was already cancelling don't raise
                # asyncio.TimeoutError and instead return None
                # to allow the cancellation to propagate
                return None
            raise TimeoutError

        self._state = _State.EXIT
        # Wake a pending _on_wait(); it sees the EXIT state and does nothing.
        self._wait_zone.set()
        return None

    @property
    def state(self) -> _State:
        """Return state of the Global task."""
        return self._state

    def zones_done_signal(self) -> None:
        """Signal that all zones are done."""
        self._wait_zone.set()

    def _start_timer(self) -> None:
        """Start timeout handler, unless one is already scheduled."""
        if self._timeout_handler:
            return

        self._expiration_time = self._loop.time() + self._time_left
        self._timeout_handler = self._loop.call_at(
            self._expiration_time, self._on_timeout
        )

    def _stop_timer(self) -> None:
        """Stop the global timer and remember the remaining time."""
        if self._timeout_handler is None:
            return

        self._timeout_handler.cancel()
        self._timeout_handler = None
        # Calculate new timeout
        # Compare against None explicitly: the value is a float, so a
        # truthiness check would be the wrong test for "has been set".
        assert self._expiration_time is not None
        self._time_left = self._expiration_time - self._loop.time()

    def _on_timeout(self) -> None:
        """Process timeout."""
        self._state = _State.TIMEOUT
        self._timeout_handler = None

        if self._manager.zones_done:
            self._cancel_task()
        else:
            # Zones are still running: postpone the cancellation until the
            # zones-done signal arrives.
            self._on_wait_task = asyncio.create_task(self._on_wait())

    def _cancel_task(self) -> None:
        """Cancel own task."""
        if not self._task.done():
            self._task.cancel("Global task timeout")

    def pause(self) -> None:
        """Pause the timer while a freeze is active."""
        self._stop_timer()

    def reset(self) -> None:
        """Restart the timer after a freeze."""
        self._start_timer()

    async def _on_wait(self) -> None:
        """Wait until zones are done, then cancel if still timed out."""
        await self._wait_zone.wait()
        await asyncio.sleep(self._cool_down)  # Allow context switch
        self._on_wait_task = None
        # The context may have exited normally in the meantime.
        if self.state is _State.TIMEOUT:
            self._cancel_task()
|
|
|
|
|
|
class _ZoneTaskContext:
    """Context manager that tracks an active task for a zone.

    Entering registers the context with its zone and, unless the zone is
    frozen, arms a timer that cancels the tracked task when it fires.  On
    exit, a cancellation that was caused by the timeout is turned into
    ``TimeoutError``.
    """

    def __init__(
        self,
        zone: _ZoneTimeoutManager,
        task: asyncio.Task[Any],
        timeout: float,
    ) -> None:
        """Initialize internal timeout context manager.

        Args:
            zone: Zone this context registers with.
            task: Task that gets cancelled when the timeout is hit.
            timeout: Time budget, added to ``loop.time()`` to schedule the
                timer.
        """
        self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
        self._zone: _ZoneTimeoutManager = zone
        self._task: asyncio.Task[Any] = task
        self._state: _State = _State.INIT
        self._time_left: float = timeout
        self._expiration_time: float | None = None
        self._timeout_handler: asyncio.Handle | None = None
        self._cancelling: int = 0

    @property
    def state(self) -> _State:
        """Return state of the Zone task."""
        return self._state

    async def __aenter__(self) -> Self:
        """Register with the zone and start the timer if it is not frozen."""
        self._zone.enter_task(self)
        self._state = _State.ACTIVE

        # Only start the timer when the zone is not frozen; a frozen zone
        # starts it later through reset().
        if self._zone.freezes_done:
            self._start_timer()

        # Remember if the task was already cancelling
        # so when we __aexit__ we can decide if we should
        # raise asyncio.TimeoutError or let the cancellation propagate
        self._cancelling = self._task.cancelling()

        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """Unregister from the zone and translate a timeout cancellation.

        Raises:
            TimeoutError: If the block was left through a ``CancelledError``
                after the timeout fired and no other cancellation request is
                outstanding on the task.
        """
        self._zone.exit_task(self)
        self._stop_timer()

        # Timeout on exit
        if exc_type is asyncio.CancelledError and self.state is _State.TIMEOUT:
            # The timeout was hit, and the task was cancelled
            # so we need to uncancel the task since the cancellation
            # should not leak out of the context manager
            if self._task.uncancel() > self._cancelling:
                # If the task was already cancelling don't raise
                # asyncio.TimeoutError and instead return None
                # to allow the cancellation to propagate
                return None
            raise TimeoutError

        self._state = _State.EXIT
        return None

    def _start_timer(self) -> None:
        """Start timeout handler, unless one is already scheduled."""
        if self._timeout_handler:
            return

        self._expiration_time = self._loop.time() + self._time_left
        self._timeout_handler = self._loop.call_at(
            self._expiration_time, self._on_timeout
        )

    def _stop_timer(self) -> None:
        """Stop zone timer and remember the remaining time."""
        if self._timeout_handler is None:
            return

        self._timeout_handler.cancel()
        self._timeout_handler = None
        # Calculate new timeout
        # Compare against None explicitly: the value is a float, so a
        # truthiness check would be the wrong test for "has been set".
        assert self._expiration_time is not None
        self._time_left = self._expiration_time - self._loop.time()

    def _on_timeout(self) -> None:
        """Process timeout."""
        self._state = _State.TIMEOUT
        self._timeout_handler = None

        # Timeout
        if not self._task.done():
            self._task.cancel("Zone timeout")

    def pause(self) -> None:
        """Pause the timer while a freeze is active."""
        self._stop_timer()

    def reset(self) -> None:
        """Restart the timer after a freeze."""
        self._start_timer()
|
|
|
|
|
|
class _ZoneTimeoutManager:
    """Manage the timeouts for a zone.

    Keeps the task and freeze contexts currently registered for one named
    zone and asks the manager to drop the zone once both lists are empty.
    """

    def __init__(self, manager: TimeoutManager, zone: str) -> None:
        """Initialize internal timeout context manager.

        Args:
            manager: Owning manager, consulted for global freezes and told
                when the zone can be dropped.
            zone: Name of this zone.
        """
        self._manager: TimeoutManager = manager
        self._zone: str = zone
        self._tasks: list[_ZoneTaskContext] = []
        self._freezes: list[_ZoneFreezeContext] = []

    def __repr__(self) -> str:
        """Representation of a zone: name, task count and freeze count."""
        return f"<{self.name}: {len(self._tasks)} / {len(self._freezes)}>"

    @property
    def name(self) -> str:
        """Return Zone name."""
        return self._zone

    @property
    def active(self) -> bool:
        """Return True if the zone holds at least one task or freeze."""
        return bool(self._tasks or self._freezes)

    @property
    def freezes_done(self) -> bool:
        """Return True if neither the zone nor the manager has a freeze."""
        return not self._freezes and self._manager.freezes_done

    def enter_task(self, task: _ZoneTaskContext) -> None:
        """Start into new Task."""
        self._tasks.append(task)

    def exit_task(self, task: _ZoneTaskContext) -> None:
        """Exit a running Task."""
        self._tasks.remove(task)

        # Last listener gone: the zone is no longer needed
        if not self.active:
            self._manager.drop_zone(self.name)

    def enter_freeze(self, freeze: _ZoneFreezeContext) -> None:
        """Start into new freeze."""
        self._freezes.append(freeze)

    def exit_freeze(self, freeze: _ZoneFreezeContext) -> None:
        """Exit a running Freeze."""
        self._freezes.remove(freeze)

        # Last listener gone: the zone is no longer needed
        if not self.active:
            self._manager.drop_zone(self.name)

    def pause(self) -> None:
        """Stop the timers of all tasks while the zone is frozen."""
        # Forward pause (a no-op when the zone has no tasks)
        for task in self._tasks:
            task.pause()

    def reset(self) -> None:
        """Restart the timers of all tasks after a freeze."""
        # Forward reset (a no-op when the zone has no tasks)
        for task in self._tasks:
            task.reset()
|
|
|
|
|
|
class TimeoutManager:
|
|
"""Class to manage timeouts over different zones.
|
|
|
|
Manages both global and zone based timeouts.
|
|
"""
|
|
|
|
def __init__(self) -> None:
|
|
"""Initialize TimeoutManager."""
|
|
self._loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
|
|
self._zones: dict[str, _ZoneTimeoutManager] = {}
|
|
self._globals: list[_GlobalTaskContext] = []
|
|
self._freezes: list[_GlobalFreezeContext] = []
|
|
|
|
@property
|
|
def zones_done(self) -> bool:
|
|
"""Return True if all zones are finished."""
|
|
return not bool(self._zones)
|
|
|
|
@property
|
|
def freezes_done(self) -> bool:
|
|
"""Return True if all freezes are finished."""
|
|
return not self._freezes
|
|
|
|
@property
|
|
def zones(self) -> dict[str, _ZoneTimeoutManager]:
|
|
"""Return all Zones."""
|
|
return self._zones
|
|
|
|
@property
|
|
def global_tasks(self) -> list[_GlobalTaskContext]:
|
|
"""Return all global Tasks."""
|
|
return self._globals
|
|
|
|
@property
|
|
def global_freezes(self) -> list[_GlobalFreezeContext]:
|
|
"""Return all global Freezes."""
|
|
return self._freezes
|
|
|
|
def drop_zone(self, zone_name: str) -> None:
|
|
"""Drop a zone out of scope."""
|
|
self._zones.pop(zone_name, None)
|
|
if self._zones:
|
|
return
|
|
|
|
# Signal Global task, all zones are done
|
|
for task in self._globals:
|
|
task.zones_done_signal()
|
|
|
|
def async_timeout(
|
|
self, timeout: float, zone_name: str = ZONE_GLOBAL, cool_down: float = 0
|
|
) -> _ZoneTaskContext | _GlobalTaskContext:
|
|
"""Timeout based on a zone.
|
|
|
|
For using as Async Context Manager.
|
|
"""
|
|
current_task: asyncio.Task[Any] | None = asyncio.current_task()
|
|
assert current_task
|
|
|
|
# Global Zone
|
|
if zone_name == ZONE_GLOBAL:
|
|
return _GlobalTaskContext(self, current_task, timeout, cool_down)
|
|
|
|
# Zone Handling
|
|
if zone_name in self.zones:
|
|
zone: _ZoneTimeoutManager = self.zones[zone_name]
|
|
else:
|
|
self.zones[zone_name] = zone = _ZoneTimeoutManager(self, zone_name)
|
|
|
|
# Create Task
|
|
return _ZoneTaskContext(zone, current_task, timeout)
|
|
|
|
def async_freeze(
|
|
self, zone_name: str = ZONE_GLOBAL
|
|
) -> _ZoneFreezeContext | _GlobalFreezeContext:
|
|
"""Freeze all timer until job is done.
|
|
|
|
For using as Async Context Manager.
|
|
"""
|
|
# Global Freeze
|
|
if zone_name == ZONE_GLOBAL:
|
|
return _GlobalFreezeContext(self)
|
|
|
|
# Zone Freeze
|
|
if zone_name in self.zones:
|
|
zone: _ZoneTimeoutManager = self.zones[zone_name]
|
|
else:
|
|
self.zones[zone_name] = zone = _ZoneTimeoutManager(self, zone_name)
|
|
|
|
return _ZoneFreezeContext(zone)
|
|
|
|
def freeze(
|
|
self, zone_name: str = ZONE_GLOBAL
|
|
) -> _ZoneFreezeContext | _GlobalFreezeContext:
|
|
"""Freeze all timer until job is done.
|
|
|
|
For using as Context Manager.
|
|
"""
|
|
return run_callback_threadsafe(
|
|
self._loop, self.async_freeze, zone_name
|
|
).result()
|