mirror of https://github.com/home-assistant/core
193 lines
6.9 KiB
Python
193 lines
6.9 KiB
Python
"""Fixtures for the recorder component tests."""
|
|
|
|
from collections.abc import Callable, Generator
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass
|
|
from functools import partial
|
|
import threading
|
|
from unittest.mock import Mock, patch
|
|
|
|
import pytest
|
|
from sqlalchemy.engine import Engine
|
|
from sqlalchemy.orm.session import Session
|
|
|
|
from homeassistant.components import recorder
|
|
from homeassistant.components.recorder import db_schema
|
|
from homeassistant.components.recorder.util import session_scope
|
|
from homeassistant.core import HomeAssistant
|
|
|
|
|
|
def pytest_configure(config) -> None:
    """Register the custom ``skip_on_db_engine`` marker with pytest.

    The description must match how the ``skip_by_db_engine`` fixture in this
    module interprets the marker: tests carrying it are *skipped* on the named
    DB engine(s); they are not restricted to run only on them.
    """
    config.addinivalue_line(
        "markers",
        # Previously read "run only on", which inverted the marker's meaning.
        "skip_on_db_engine(engine): mark test to skip on named DB engine(s)",
    )
|
|
|
|
|
|
@pytest.fixture
def skip_by_db_engine(request: pytest.FixtureRequest, recorder_db_url: str) -> None:
    """Fixture to skip tests on unsupported DB engines.

    Mark the test with @pytest.mark.skip_on_db_engine("mysql") to skip on mysql, or
    @pytest.mark.skip_on_db_engine(["mysql", "sqlite"]) to skip on mysql and sqlite.

    The current engine name is the scheme of ``recorder_db_url``, i.e. the
    text before ``://``.
    """
    # Look the marker up once instead of calling get_closest_marker twice.
    if marker := request.node.get_closest_marker("skip_on_db_engine"):
        skip_on_db_engine = marker.args[0]
        # A single engine may be given as a plain string; normalise to a list
        # so the membership test below is not a substring check.
        if isinstance(skip_on_db_engine, str):
            skip_on_db_engine = [skip_on_db_engine]
        db_engine = recorder_db_url.partition("://")[0]
        if db_engine in skip_on_db_engine:
            pytest.skip(f"skipped for DB engine: {db_engine}")
|
|
|
|
|
|
@pytest.fixture
def recorder_dialect_name(hass: HomeAssistant, db_engine: str) -> Generator[None]:
    """Make the recorder report ``db_engine`` as its dialect name during a test.

    Without a recorder instance in ``hass.data``, ``Recorder.dialect_name`` is
    patched on the class. With one, the instance's private ``_dialect_name`` is
    patched instead. Any ``dialect_name`` entry in the instance ``__dict__`` is
    dropped both before and after the patch.
    """
    instance = hass.data.get(recorder.DATA_INSTANCE)

    if not instance:
        # No recorder instance yet: patch the attribute on the class.
        with patch(
            "homeassistant.components.recorder.Recorder.dialect_name", db_engine
        ):
            yield
        return

    # NOTE(review): presumably dialect_name is a cached property whose value
    # lives in __dict__; dropping it forces a re-read of the patched value.
    instance.__dict__.pop("dialect_name", None)
    with patch.object(instance, "_dialect_name", db_engine):
        yield
    instance.__dict__.pop("dialect_name", None)
|
|
|
|
|
|
@dataclass(slots=True)
class InstrumentedMigration:
    """Container to aid controlling migration progress.

    Built and yielded by ``instrument_migration``. That function patches the
    recorder's live/non-live schema migration functions and ``_apply_update``
    so they signal and wait on the events held here. This lets a test observe
    and pause a migration at well-defined points.
    """

    # Set by the patched migrate_schema_live once the real migration has
    # finished (or raised).
    live_migration_done: threading.Event
    # Waited on right after live_migration_done is set. instrument_migration
    # sets it up front, so it only stalls if a test clears it.
    live_migration_done_stall: threading.Event
    # Waited on by the patched _apply_update before calling the real one.
    # It is NOT set up front, so a test must set it for the migration to
    # proceed (see stall_on_schema_version).
    migration_stall: threading.Event
    # Set when either patched migrate_schema_* function is entered.
    migration_started: threading.Event
    # schema_version of the newest SchemaChanges row, read after a successful
    # migration; None until then.
    migration_version: int | None
    # Non-live counterpart of live_migration_done.
    non_live_migration_done: threading.Event
    # Non-live counterpart of live_migration_done_stall (also set up front).
    non_live_migration_done_stall: threading.Event
    # The mock patched over recorder.migration._apply_update; it wraps the
    # instrumented wrapper, which calls the real function.
    apply_update_mock: Mock
    # Only _apply_update calls with this new_version wait on migration_stall;
    # None means every call waits.
    stall_on_schema_version: int | None
    # Set just before _apply_update starts waiting on migration_stall.
    apply_update_stalled: threading.Event
    # new_version argument of the most recent _apply_update call; None until
    # the first call.
    apply_update_version: int | None
|
|
|
|
|
|
@pytest.fixture(name="instrument_migration")
def instrument_migration_fixture(
    hass: HomeAssistant,
) -> Generator[InstrumentedMigration]:
    """Yield an InstrumentedMigration that stays active for the whole test.

    Exposed under the fixture name ``instrument_migration``. It delegates to
    the ``instrument_migration`` context manager, so its patches are in place
    while the test runs and are undone when the context exits.
    """
    migration_patches = instrument_migration(hass)
    with migration_patches as instrumented:
        yield instrumented
|
|
|
|
|
|
@contextmanager
def instrument_migration(
    hass: HomeAssistant,
) -> Generator[InstrumentedMigration]:
    """Instrument recorder migration.

    Patches ``migrate_schema_live``, ``migrate_schema_non_live`` and
    ``_apply_update`` in ``homeassistant.components.recorder.migration``. Each
    wrapper still calls the real function, but also signals progress through,
    and waits on, the events of the yielded ``InstrumentedMigration``.

    Event defaults on entry:

    - The ``*_migration_done_stall`` events are set before yielding, so the
      schema-migration wrappers do not block unless a test clears them.
    - ``migration_stall`` is left unset, so ``_apply_update`` blocks on it
      until the test sets it. This applies to every schema version, or only
      ``stall_on_schema_version`` when that is not None.
    """

    real_migrate_schema_live = recorder.migration.migrate_schema_live
    real_migrate_schema_non_live = recorder.migration.migrate_schema_non_live
    real_apply_update = recorder.migration._apply_update

    def _instrument_migrate_schema_live(real_func, *args):
        """Control live migration progress and check results."""
        return _instrument_migrate_schema(
            real_func,
            args,
            instrumented_migration.live_migration_done,
            instrumented_migration.live_migration_done_stall,
        )

    def _instrument_migrate_schema_non_live(real_func, *args):
        """Control non-live migration progress and check results."""
        return _instrument_migrate_schema(
            real_func,
            args,
            instrumented_migration.non_live_migration_done,
            instrumented_migration.non_live_migration_done_stall,
        )

    def _instrument_migrate_schema(
        real_func,
        args,
        migration_done: threading.Event,
        migration_done_stall: threading.Event,
    ):
        """Run a schema migration, record the resulting version and signal.

        ``migration_done`` is always set, and ``migration_done_stall`` always
        awaited, once the migration ends. This holds whether the migration
        returned, raised, or the follow-up version check failed. Otherwise a
        test waiting on ``migration_done`` would hang forever.
        """
        instrumented_migration.migration_started.set()

        try:
            migration_result = real_func(*args)

            # Check and report the outcome of the migration; if migration fails
            # the recorder will silently create a new database.
            with session_scope(hass=hass, read_only=True) as session:
                res = (
                    session.query(db_schema.SchemaChanges)
                    .order_by(db_schema.SchemaChanges.change_id.desc())
                    .first()
                )
                # An empty SchemaChanges table is reported as None instead of
                # raising AttributeError inside the recorder thread.
                instrumented_migration.migration_version = (
                    None if res is None else res.schema_version
                )
        finally:
            # Signal after migration_version is written, so a waiter that wakes
            # on migration_done observes the recorded version.
            migration_done.set()
            migration_done_stall.wait()
        return migration_result

    def _instrument_apply_update(
        instance: recorder.Recorder,
        hass: HomeAssistant,
        engine: Engine,
        session_maker: Callable[[], Session],
        new_version: int,
        old_version: int,
    ):
        """Control migration progress.

        Records ``new_version``. If it matches ``stall_on_schema_version`` (or
        that is None), signals ``apply_update_stalled`` and blocks on
        ``migration_stall`` before delegating to the real ``_apply_update``.
        """
        instrumented_migration.apply_update_version = new_version
        stall_version = instrumented_migration.stall_on_schema_version
        if stall_version is None or stall_version == new_version:
            instrumented_migration.apply_update_stalled.set()
            instrumented_migration.migration_stall.wait()
        real_apply_update(
            instance, hass, engine, session_maker, new_version, old_version
        )

    with (
        patch(
            "homeassistant.components.recorder.migration.migrate_schema_live",
            wraps=partial(_instrument_migrate_schema_live, real_migrate_schema_live),
        ),
        patch(
            "homeassistant.components.recorder.migration.migrate_schema_non_live",
            wraps=partial(
                _instrument_migrate_schema_non_live, real_migrate_schema_non_live
            ),
        ),
        patch(
            "homeassistant.components.recorder.migration._apply_update",
            wraps=_instrument_apply_update,
        ) as apply_update_mock,
    ):
        # Bound here, before the yield; the closures above only read it once
        # the patched functions are actually called.
        instrumented_migration = InstrumentedMigration(
            live_migration_done=threading.Event(),
            live_migration_done_stall=threading.Event(),
            migration_stall=threading.Event(),
            migration_started=threading.Event(),
            migration_version=None,
            non_live_migration_done=threading.Event(),
            non_live_migration_done_stall=threading.Event(),
            apply_update_mock=apply_update_mock,
            stall_on_schema_version=None,
            apply_update_stalled=threading.Event(),
            apply_update_version=None,
        )

        # Don't stall after schema migrations by default; tests may clear these.
        instrumented_migration.live_migration_done_stall.set()
        instrumented_migration.non_live_migration_done_stall.set()
        yield instrumented_migration
|