core/homeassistant/util/executor.py

103 lines
2.9 KiB
Python

"""Executor util helpers."""
from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor
import contextlib
import logging
import sys
from threading import Thread
import time
import traceback
from typing import Any
from .thread import async_raise
_LOGGER = logging.getLogger(__name__)
MAX_LOG_ATTEMPTS = 2
_JOIN_ATTEMPTS = 10
EXECUTOR_SHUTDOWN_TIMEOUT = 10
def _log_thread_running_at_shutdown(name: str, ident: int) -> None:
"""Log the stack of a thread that was still running at shutdown."""
frames = sys._current_frames() # noqa: SLF001
stack = frames.get(ident)
formatted_stack = traceback.format_stack(stack)
_LOGGER.warning(
"Thread[%s] is still running at shutdown: %s",
name,
"".join(formatted_stack).strip(),
)
def join_or_interrupt_threads(
threads: set[Thread], timeout: float, log: bool
) -> set[Thread]:
"""Attempt to join or interrupt a set of threads."""
joined = set()
timeout_per_thread = timeout / len(threads)
for thread in threads:
thread.join(timeout=timeout_per_thread)
if not thread.is_alive() or thread.ident is None:
joined.add(thread)
continue
if log:
_log_thread_running_at_shutdown(thread.name, thread.ident)
with contextlib.suppress(SystemError):
# SystemError at this stage is usually a race condition
# where the thread happens to die right before we force
# it to raise the exception
async_raise(thread.ident, SystemExit)
return joined
class InterruptibleThreadPoolExecutor(ThreadPoolExecutor):
"""A ThreadPoolExecutor instance that will not deadlock on shutdown."""
def shutdown(
self, *args: Any, join_threads_or_timeout: bool = True, **kwargs: Any
) -> None:
"""Shutdown with interrupt support added.
By default shutdown will wait for threads to finish up
to the timeout before forcefully stopping them. This can
be disabled by setting `join_threads_or_timeout` to False.
"""
super().shutdown(wait=False, cancel_futures=True)
if join_threads_or_timeout:
self.join_threads_or_timeout()
def join_threads_or_timeout(self) -> None:
"""Join threads or timeout."""
remaining_threads = set(self._threads)
start_time = time.monotonic()
timeout_remaining: float = EXECUTOR_SHUTDOWN_TIMEOUT
attempt = 0
while True:
if not remaining_threads:
return
attempt += 1
remaining_threads -= join_or_interrupt_threads(
remaining_threads,
timeout_remaining / _JOIN_ATTEMPTS,
attempt <= MAX_LOG_ATTEMPTS,
)
timeout_remaining = EXECUTOR_SHUTDOWN_TIMEOUT - (
time.monotonic() - start_time
)
if timeout_remaining <= 0:
return