mirror of https://github.com/home-assistant/core
588 lines
20 KiB
Python
588 lines
20 KiB
Python
"""Config Flow using OAuth2.
|
|
|
|
This module consists of the following parts:
|
|
- OAuth2 config flow which supports multiple OAuth2 implementations
|
|
- OAuth2 implementation that works with a locally provided client ID/secret
|
|
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from abc import ABC, ABCMeta, abstractmethod
|
|
import asyncio
|
|
from asyncio import Lock
|
|
from collections.abc import Awaitable, Callable
|
|
from http import HTTPStatus
|
|
from json import JSONDecodeError
|
|
import logging
|
|
import secrets
|
|
import time
|
|
from typing import Any, cast
|
|
|
|
from aiohttp import ClientError, ClientResponseError, client, web
|
|
import jwt
|
|
import voluptuous as vol
|
|
from yarl import URL
|
|
|
|
from homeassistant import config_entries
|
|
from homeassistant.components import http
|
|
from homeassistant.core import HomeAssistant, callback
|
|
from homeassistant.loader import async_get_application_credentials
|
|
from homeassistant.util.hass_dict import HassKey
|
|
|
|
from .aiohttp_client import async_get_clientsession
|
|
from .network import NoURLAvailableError
|
|
|
|
_LOGGER = logging.getLogger(__name__)

# hass.data key holding the secret used to sign/verify the OAuth2 ``state``
# JWT; it is generated lazily on first use by ``_encode_jwt``.
DATA_JWT_SECRET = "oauth2_jwt_secret"
# Registered implementations:
# integration domain -> {implementation domain -> implementation}.
DATA_IMPLEMENTATIONS: HassKey[dict[str, dict[str, AbstractOAuth2Implementation]]] = (
    HassKey("oauth2_impl")
)
# Implementation providers: provider domain -> async callable that is given
# (hass, integration domain) and returns extra implementations for that domain.
DATA_PROVIDERS: HassKey[
    dict[
        str,
        Callable[[HomeAssistant, str], Awaitable[list[AbstractOAuth2Implementation]]],
    ]
] = HassKey("oauth2_providers")
# Path of the local view that receives the provider's authorization redirect.
AUTH_CALLBACK_PATH = "/auth/external/callback"
# Request header carrying the frontend base URL; used to build the local
# redirect URI.
HEADER_FRONTEND_BASE = "HA-Frontend-Base"
# Redirect URI used instead of the local callback when the ``my`` integration
# is loaded.
MY_AUTH_CALLBACK_PATH = "https://my.home-assistant.io/redirect/oauth"

# A token is only considered valid if it expires more than this many seconds
# from now (see OAuth2Session.valid_token), leaving margin for clock skew.
CLOCK_OUT_OF_SYNC_MAX_SEC = 20

# Timeout for generating the authorize URL (async_step_auth).
OAUTH_AUTHORIZE_URL_TIMEOUT_SEC = 30
# Timeout for exchanging external data for a token (async_step_creation).
OAUTH_TOKEN_TIMEOUT_SEC = 30
|
|
|
|
|
|
class AbstractOAuth2Implementation(ABC):
    """Interface every OAuth2 implementation must provide."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Human-readable name of this implementation."""

    @property
    @abstractmethod
    def domain(self) -> str:
        """Domain of the integration providing this implementation."""

    @abstractmethod
    async def async_generate_authorize_url(self, flow_id: str) -> str:
        """Build the URL the user visits to authorize Home Assistant.

        Called when a config flow starts. The URL should send the user to the
        vendor's website, where they grant access to Home Assistant.

        The implementation must arrange to be notified once the user has
        authorized, and then hand the result to the config flow identified by
        ``flow_id``. Do as little as possible at that point and leave the rest
        to ``async_resolve_external_data``; that gives the best UX.

        External data is handed back with:

            await hass.config_entries.flow.async_configure(
                flow_id=flow_id, user_input={'code': 'abcd', 'state': … }
            )
        """

    @abstractmethod
    async def async_resolve_external_data(self, external_data: Any) -> dict:
        """Exchange external step data for tokens.

        ``external_data`` is whatever the implementation passed to the config
        flow during the external step. The returned tokens are stored as
        'token' in the config entry data.
        """

    async def async_refresh_token(self, token: dict) -> dict:
        """Refresh ``token`` and recompute its expiry information.

        Delegates the request to ``_async_refresh_token``. It then coerces
        ``expires_in`` to an int and stores the absolute expiry time (epoch
        seconds) under ``expires_at``.
        """
        refreshed = await self._async_refresh_token(token)
        # Non-compliant providers may send expires_in as a non-int value.
        lifetime = int(refreshed["expires_in"])
        refreshed["expires_in"] = lifetime
        refreshed["expires_at"] = time.time() + lifetime
        return refreshed

    @abstractmethod
    async def _async_refresh_token(self, token: dict) -> dict:
        """Perform the provider-specific token refresh."""
|
|
|
|
|
|
class LocalOAuth2Implementation(AbstractOAuth2Implementation):
    """Local OAuth2 implementation using a locally provided client ID/secret."""

    def __init__(
        self,
        hass: HomeAssistant,
        domain: str,
        client_id: str,
        client_secret: str,
        authorize_url: str,
        token_url: str,
    ) -> None:
        """Initialize local auth implementation.

        Args:
            hass: Home Assistant instance.
            domain: Domain providing this implementation.
            client_id: OAuth2 client ID, sent with authorize and token requests.
            client_secret: OAuth2 client secret; when ``None`` it is omitted
                from token requests.
            authorize_url: Provider endpoint the user is sent to.
            token_url: Provider endpoint used to obtain and refresh tokens.
        """
        self.hass = hass
        self._domain = domain
        self.client_id = client_id
        self.client_secret = client_secret
        self.authorize_url = authorize_url
        self.token_url = token_url

    @property
    def name(self) -> str:
        """Name of the implementation."""
        return "Configuration.yaml"

    @property
    def domain(self) -> str:
        """Domain providing the implementation."""
        return self._domain

    @property
    def redirect_uri(self) -> str:
        """Return the redirect uri.

        If the ``my`` integration is loaded, the My Home Assistant redirect is
        used. Otherwise the local callback URL is built from the frontend base
        sent in the ``HA-Frontend-Base`` header of the current request.

        Raises:
            RuntimeError: if there is no current request, or it lacks that
                header.
        """
        if "my" in self.hass.config.components:
            return MY_AUTH_CALLBACK_PATH

        if (req := http.current_request.get()) is None:
            raise RuntimeError("No current request in context")

        if (ha_host := req.headers.get(HEADER_FRONTEND_BASE)) is None:
            raise RuntimeError("No header in request")

        return f"{ha_host}{AUTH_CALLBACK_PATH}"

    @property
    def extra_authorize_data(self) -> dict:
        """Extra data that needs to be appended to the authorize url."""
        return {}

    async def async_generate_authorize_url(self, flow_id: str) -> str:
        """Generate a url for the user to authorize.

        The ``state`` query parameter is a signed JWT carrying the flow ID and
        the redirect URI. The callback can use it to resume the right flow and
        to reuse the same redirect URI in the token request.
        """
        redirect_uri = self.redirect_uri
        return str(
            URL(self.authorize_url)
            .with_query(
                {
                    "response_type": "code",
                    "client_id": self.client_id,
                    "redirect_uri": redirect_uri,
                    "state": _encode_jwt(
                        self.hass, {"flow_id": flow_id, "redirect_uri": redirect_uri}
                    ),
                }
            )
            .update_query(self.extra_authorize_data)
        )

    async def async_resolve_external_data(self, external_data: Any) -> dict:
        """Resolve the authorization code to tokens.

        The redirect URI is read back from the decoded state, so it matches
        the one sent in the authorize request.
        """
        return await self._token_request(
            {
                "grant_type": "authorization_code",
                "code": external_data["code"],
                "redirect_uri": external_data["state"]["redirect_uri"],
            }
        )

    async def _async_refresh_token(self, token: dict) -> dict:
        """Refresh tokens.

        The response is merged over the old token, so keys absent from the
        refresh response are kept from the previous token.
        """
        new_token = await self._token_request(
            {
                "grant_type": "refresh_token",
                "client_id": self.client_id,
                "refresh_token": token["refresh_token"],
            }
        )
        return {**token, **new_token}

    async def _token_request(self, data: dict) -> dict:
        """Make a token request.

        Adds the client credentials to ``data`` and POSTs it to ``token_url``.

        Returns:
            The decoded JSON response body.

        Raises:
            ClientResponseError: when the provider answers with an HTTP error
                status. The provider's error details are logged first.
            ClientError: on other client/transport failures.
        """
        session = async_get_clientsession(self.hass)

        data["client_id"] = self.client_id

        if self.client_secret is not None:
            data["client_secret"] = self.client_secret

        _LOGGER.debug("Sending token request to %s", self.token_url)
        resp = await session.post(self.token_url, data=data)
        if resp.status >= 400:
            try:
                error_response = await resp.json()
            except (ClientError, JSONDecodeError):
                error_response = {}
            # The error body is provider-controlled. A JSON value that is not
            # an object (list, string, null) would break .get() below with an
            # AttributeError and mask the HTTP error from raise_for_status().
            if not isinstance(error_response, dict):
                error_response = {}
            error_code = error_response.get("error", "unknown")
            error_description = error_response.get("error_description", "unknown error")
            _LOGGER.error(
                "Token request for %s failed (%s): %s",
                self.domain,
                error_code,
                error_description,
            )
        resp.raise_for_status()
        return cast(dict, await resp.json())
|
|
|
|
|
|
class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
    """Handle a config flow."""

    DOMAIN = ""

    VERSION = 1

    def __init__(self) -> None:
        """Instantiate config flow.

        Raises:
            TypeError: if the subclass did not set ``DOMAIN``.
        """
        if self.DOMAIN == "":
            raise TypeError(
                f"Can't instantiate class {self.__class__.__name__} without DOMAIN"
                " being set"
            )

        # Data delivered by the external step; set in async_step_auth.
        self.external_data: Any = None
        # Implementation chosen in async_step_pick_implementation.
        self.flow_impl: AbstractOAuth2Implementation = None  # type: ignore[assignment]

    @property
    @abstractmethod
    def logger(self) -> logging.Logger:
        """Return logger."""

    @property
    def extra_authorize_data(self) -> dict:
        """Extra data that needs to be appended to the authorize url."""
        return {}

    async def async_generate_authorize_url(self) -> str:
        """Generate a url for the user to authorize.

        Asks the chosen implementation for its URL, then merges in this flow's
        ``extra_authorize_data`` as query parameters.
        """
        url = await self.flow_impl.async_generate_authorize_url(self.flow_id)
        return str(URL(url).update_query(self.extra_authorize_data))

    async def async_step_pick_implementation(
        self, user_input: dict | None = None
    ) -> config_entries.ConfigFlowResult:
        """Handle a flow start.

        Selects the OAuth2 implementation:
        - aborts when none is available;
        - picks it automatically when there is exactly one and the step runs
          within a user request;
        - otherwise shows a form to choose one.
        """
        implementations = await async_get_implementations(self.hass, self.DOMAIN)

        if user_input is not None:
            self.flow_impl = implementations[user_input["implementation"]]
            return await self.async_step_auth()

        if not implementations:
            # If the domain is listed by async_get_application_credentials,
            # report missing credentials rather than missing configuration.
            if self.DOMAIN in await async_get_application_credentials(self.hass):
                return self.async_abort(reason="missing_credentials")
            return self.async_abort(reason="missing_configuration")

        req = http.current_request.get()
        if len(implementations) == 1 and req is not None:
            # Pick first implementation if we have only one, but only
            # if this is triggered by a user interaction (request).
            self.flow_impl = next(iter(implementations.values()))
            return await self.async_step_auth()

        return self.async_show_form(
            step_id="pick_implementation",
            data_schema=vol.Schema(
                {
                    vol.Required(
                        "implementation", default=next(iter(implementations))
                    ): vol.In({key: impl.name for key, impl in implementations.items()})
                }
            ),
        )

    async def async_step_auth(
        self, user_input: dict[str, Any] | None = None
    ) -> config_entries.ConfigFlowResult:
        """Create an entry for auth.

        Without input, sends the user to the authorize URL as an external
        step. When called again with the external data, it stores that data
        and continues to ``creation``. If the data contains an ``error``, it
        continues to ``authorize_rejected`` instead.
        """
        # Flow has been triggered by external data
        if user_input is not None:
            self.external_data = user_input
            next_step = "authorize_rejected" if "error" in user_input else "creation"
            return self.async_external_step_done(next_step_id=next_step)

        try:
            async with asyncio.timeout(OAUTH_AUTHORIZE_URL_TIMEOUT_SEC):
                url = await self.async_generate_authorize_url()
        except TimeoutError as err:
            _LOGGER.error("Timeout generating authorize url: %s", err)
            return self.async_abort(reason="authorize_url_timeout")
        except NoURLAvailableError:
            return self.async_abort(
                reason="no_url_available",
                description_placeholders={
                    "docs_url": (
                        "https://www.home-assistant.io/more-info/no-url-available"
                    )
                },
            )

        return self.async_external_step(step_id="auth", url=url)

    async def async_step_creation(
        self, user_input: dict[str, Any] | None = None
    ) -> config_entries.ConfigFlowResult:
        """Create config entry from external data.

        Exchanges the external data for a token and normalizes its expiry.
        The result is handed to ``async_oauth_create_entry``.

        Aborts the flow on a timeout, on a client/HTTP error, or when the
        token has no usable ``expires_in``.
        """
        _LOGGER.debug("Creating config entry from external data")

        try:
            async with asyncio.timeout(OAUTH_TOKEN_TIMEOUT_SEC):
                token = await self.flow_impl.async_resolve_external_data(
                    self.external_data
                )
        except TimeoutError as err:
            _LOGGER.error("Timeout resolving OAuth token: %s", err)
            return self.async_abort(reason="oauth_timeout")
        except ClientError as err:
            # ClientResponseError is a ClientError subclass; it carries the
            # HTTP status used to single out unauthorized responses.
            _LOGGER.error("Error resolving OAuth token: %s", err)
            if (
                isinstance(err, ClientResponseError)
                and err.status == HTTPStatus.UNAUTHORIZED
            ):
                return self.async_abort(reason="oauth_unauthorized")
            return self.async_abort(reason="oauth_failed")

        if "expires_in" not in token:
            # Log only the keys: the token holds access/refresh credentials
            # that must not end up in the log.
            _LOGGER.warning(
                "Invalid token, missing expires_in; received keys: %s", list(token)
            )
            return self.async_abort(reason="oauth_error")

        # Force int for non-compliant oauth2 providers. A non-numeric string
        # raises ValueError; null or a non-scalar value raises TypeError.
        try:
            token["expires_in"] = int(token["expires_in"])
        except (TypeError, ValueError) as err:
            _LOGGER.warning("Error converting expires_in to int: %s", err)
            return self.async_abort(reason="oauth_error")
        token["expires_at"] = time.time() + token["expires_in"]

        self.logger.info("Successfully authenticated")

        return await self.async_oauth_create_entry(
            {"auth_implementation": self.flow_impl.domain, "token": token}
        )

    async def async_step_authorize_rejected(
        self, data: None = None
    ) -> config_entries.ConfigFlowResult:
        """Step to handle flow rejection.

        Reached when the external data contains an ``error`` instead of a
        code. Aborts and passes that error to the user as a placeholder.
        """
        return self.async_abort(
            reason="user_rejected_authorize",
            description_placeholders={"error": self.external_data["error"]},
        )

    async def async_oauth_create_entry(
        self, data: dict
    ) -> config_entries.ConfigFlowResult:
        """Create an entry for the flow.

        Ok to override if you want to fetch extra info or even add another step.
        """
        return self.async_create_entry(title=self.flow_impl.name, data=data)

    async def async_step_user(
        self, user_input: dict[str, Any] | None = None
    ) -> config_entries.ConfigFlowResult:
        """Handle a flow start."""
        return await self.async_step_pick_implementation(user_input)

    @classmethod
    def async_register_implementation(
        cls, hass: HomeAssistant, local_impl: LocalOAuth2Implementation
    ) -> None:
        """Register a local implementation under this flow's DOMAIN."""
        async_register_implementation(hass, cls.DOMAIN, local_impl)
|
|
|
|
|
|
@callback
def async_register_implementation(
    hass: HomeAssistant, domain: str, implementation: AbstractOAuth2Implementation
) -> None:
    """Register ``implementation`` as an OAuth2 option for integration ``domain``.

    Entries are keyed by the implementation's own ``domain``. Registering
    another implementation with the same domain for the same integration
    replaces the previous one.
    """
    registry = hass.data.setdefault(DATA_IMPLEMENTATIONS, {})
    per_integration = registry.setdefault(domain, {})
    per_integration[implementation.domain] = implementation
|
|
|
|
|
|
async def async_get_implementations(
    hass: HomeAssistant, domain: str
) -> dict[str, AbstractOAuth2Implementation]:
    """Return OAuth2 implementations for specified domain.

    Combines two sources:
    - implementations registered with ``async_register_implementation``;
    - implementations returned by every provider added with
      ``async_add_implementation_provider``.

    Provider results override registered entries with the same
    implementation domain.

    Returns:
        A new dict keyed by implementation domain. Mutating it never affects
        the registry, whether or not any providers are installed.
    """
    # Always copy: returning the registry's own dict would let callers
    # mutate registered implementations by accident.
    registered = dict(hass.data.setdefault(DATA_IMPLEMENTATIONS, {}).get(domain, {}))

    # Snapshot providers, since the mapping may change while we await.
    for get_impl in list(hass.data.get(DATA_PROVIDERS, {}).values()):
        for impl in await get_impl(hass, domain):
            registered[impl.domain] = impl

    return registered
|
|
|
|
|
|
async def async_get_config_entry_implementation(
    hass: HomeAssistant, config_entry: config_entries.ConfigEntry
) -> AbstractOAuth2Implementation:
    """Look up the implementation a config entry was created with.

    Raises:
        ValueError: if the implementation named by the entry's
            ``auth_implementation`` data is no longer available.
    """
    available = await async_get_implementations(hass, config_entry.domain)
    wanted = config_entry.data["auth_implementation"]
    if (implementation := available.get(wanted)) is None:
        raise ValueError("Implementation not available")
    return implementation
|
|
|
|
|
|
@callback
def async_add_implementation_provider(
    hass: HomeAssistant,
    provider_domain: str,
    async_provide_implementation: Callable[
        [HomeAssistant, str], Awaitable[list[AbstractOAuth2Implementation]]
    ],
) -> None:
    """Add an implementation provider.

    ``async_get_implementations`` calls ``async_provide_implementation`` with
    ``(hass, domain)`` and iterates over the result. A provider with nothing
    to offer should therefore return an empty list, not ``None``.

    Adding a provider again under the same ``provider_domain`` replaces the
    previous one.
    """
    providers = hass.data.setdefault(DATA_PROVIDERS, {})
    providers[provider_domain] = async_provide_implementation
|
|
|
|
|
|
class OAuth2AuthorizeCallbackView(http.HomeAssistantView):
    """OAuth2 Authorization Callback View.

    Receives the redirect back from the provider and resumes the config flow
    named in the signed ``state`` parameter.
    """

    # The provider's browser redirect carries no Home Assistant credentials;
    # the signed state is what ties the request to a flow.
    requires_auth = False
    url = AUTH_CALLBACK_PATH
    name = "auth:external:callback"

    async def get(self, request: web.Request) -> web.Response:
        """Receive authorization code.

        ``state`` must be a JWT signed with this instance's secret (see
        ``_encode_jwt``). Requests missing parameters or carrying an invalid
        state are answered with HTTP 400.
        """
        if "state" not in request.query:
            return web.Response(text="Missing state parameter", status=400)

        hass = request.app[http.KEY_HASS]

        state = _decode_jwt(hass, request.query["state"])

        if state is None:
            return web.Response(
                text=(
                    "Invalid state. Is My Home Assistant configured "
                    "to go to the right instance?"
                ),
                status=400,
            )

        user_input: dict[str, Any] = {"state": state}

        if "code" in request.query:
            user_input["code"] = request.query["code"]
        elif "error" in request.query:
            user_input["error"] = request.query["error"]
        else:
            return web.Response(text="Missing code or error parameter", status=400)

        await hass.config_entries.flow.async_configure(
            flow_id=state["flow_id"], user_input=user_input
        )
        _LOGGER.debug("Resumed OAuth configuration flow")
        # The page only closes the popup window the user authorized in.
        return web.Response(
            headers={"content-type": "text/html"},
            text="<script>window.close()</script>",
        )
|
|
|
|
|
|
class OAuth2Session:
    """Make OAuth2-authenticated requests on behalf of a config entry.

    The token is read from ``config_entry.data["token"]``. When it is about to
    expire, it is refreshed through ``implementation`` and written back to the
    entry.
    """

    def __init__(
        self,
        hass: HomeAssistant,
        config_entry: config_entries.ConfigEntry,
        implementation: AbstractOAuth2Implementation,
    ) -> None:
        """Initialize an OAuth2 session."""
        self.hass = hass
        self.config_entry = config_entry
        self.implementation = implementation
        # Serializes refreshes: waiters re-check validity after acquiring it,
        # so a token refreshed by one caller is not refreshed again.
        self._token_lock = Lock()

    @property
    def token(self) -> dict:
        """Return the token currently stored in the config entry."""
        return cast(dict, self.config_entry.data["token"])

    @property
    def valid_token(self) -> bool:
        """Return whether the token is still valid.

        A token that expires within ``CLOCK_OUT_OF_SYNC_MAX_SEC`` seconds is
        already treated as expired.
        """
        expires_at = cast(float, self.token["expires_at"])
        return expires_at > time.time() + CLOCK_OUT_OF_SYNC_MAX_SEC

    async def async_ensure_token_valid(self) -> None:
        """Refresh and persist the token unless it is still valid."""
        async with self._token_lock:
            if not self.valid_token:
                refreshed = await self.implementation.async_refresh_token(self.token)
                self.hass.config_entries.async_update_entry(
                    self.config_entry,
                    data={**self.config_entry.data, "token": refreshed},
                )

    async def async_request(
        self, method: str, url: str, **kwargs: Any
    ) -> client.ClientResponse:
        """Make a request, refreshing the token first when needed."""
        await self.async_ensure_token_valid()
        current_token = self.config_entry.data["token"]
        return await async_oauth2_request(
            self.hass, current_token, method, url, **kwargs
        )
|
|
|
|
|
|
async def async_oauth2_request(
    hass: HomeAssistant, token: dict, method: str, url: str, **kwargs: Any
) -> client.ClientResponse:
    """Make an OAuth2 authenticated request.

    This method will not refresh tokens. Use OAuth2 session for that.

    ``kwargs`` are passed through to the shared client session's ``request``.
    Any caller-supplied ``headers`` (a mapping, or ``None``) are kept, and an
    ``authorization`` header carrying the token's bearer access token is added.
    """
    session = async_get_clientsession(hass)
    # headers=None is a valid aiohttp argument; treat it as "no headers"
    # instead of failing on ``**None`` below.
    headers = kwargs.pop("headers", None) or {}
    return await session.request(
        method,
        url,
        **kwargs,
        headers={
            **headers,
            "authorization": f"Bearer {token['access_token']}",
        },
    )
|
|
|
|
|
|
@callback
def _encode_jwt(hass: HomeAssistant, data: dict) -> str:
    """Sign ``data`` as an HS256 JWT.

    The signing secret is generated on first use and kept in ``hass.data``,
    where later calls and ``_decode_jwt`` read it.
    """
    secret = hass.data.get(DATA_JWT_SECRET)
    if secret is None:
        secret = secrets.token_hex()
        hass.data[DATA_JWT_SECRET] = secret
    return jwt.encode(data, secret, algorithm="HS256")
|
|
|
|
|
|
@callback
def _decode_jwt(hass: HomeAssistant, encoded: str) -> dict[str, Any] | None:
    """JWT decode data produced by ``_encode_jwt``.

    Returns:
        The verified payload. Returns ``None`` when no signing secret exists
        yet, or when PyJWT rejects the token (``jwt.InvalidTokenError``).
    """
    if (secret := hass.data.get(DATA_JWT_SECRET)) is None:
        return None

    try:
        payload: dict[str, Any] = jwt.decode(encoded, secret, algorithms=["HS256"])
    except jwt.InvalidTokenError:
        return None
    return payload
|