matrix-registration/matrix_registration/api.py

379 lines
11 KiB
Python

# Standard library imports...
import logging
import os
import re
from datetime import datetime
# Third-party imports...
from flask import (
Blueprint,
abort,
jsonify,
request,
make_response,
render_template,
send_file,
)
from flask_httpauth import HTTPTokenAuth
from requests import exceptions
from werkzeug.exceptions import BadRequest
from wtforms import Form, StringField, PasswordField, validators
# Local imports...
from . import config
from . import tokens
from .constants import __location__
from .limiter import limiter, get_default_rate_limit
from .matrix_api import create_account
from .translation import get_translations
# Admin API authentication: clients present a token using the "SharedSecret"
# Authorization scheme; verify_token compares it to admin_api_shared_secret.
auth = HTTPTokenAuth(scheme="SharedSecret")
logger = logging.getLogger(__name__)
# Registration page, logo and admin token API routes.
api = Blueprint("api", __name__)
# Health endpoint lives on its own blueprint, separate from `api`; the rate
# limit below is applied only to `api`.
healthcheck = Blueprint("healthcheck", __name__)
# Apply the default rate limit to every route registered on the `api` blueprint.
limiter.limit(get_default_rate_limit)(api)
def validate_token(form, token):
    """
    WTForms validator: ensure the submitted registration token is active.

    Parameters
    ----------
    form : wtforms.Form
        The form being validated (unused; required by the WTForms
        validator signature).
    token : wtforms.Field
        Field whose ``data`` holds the token name, e.g. 'DoubleWizardSki'.

    Raises
    ------
    ValidationError
        If the token store does not report the token as active.
    """
    # reload token state before checking so the lookup is current
    tokens.tokens.load()
    is_active = tokens.tokens.active(token.data)
    if is_active:
        return
    raise validators.ValidationError("Token is invalid")
def validate_username(form, username):
    """
    WTForms validator: ensure the username is a valid mxid localpart.

    Accepts either a bare localpart ('user') or a full mxid on this
    server ('@user:matrix.org'), following
    https://github.com/matrix-org/matrix-doc/blob/master/specification/appendices/identifier_grammar.rst#user-identifiers

    Parameters
    ----------
    form : wtforms.Form
        The form being validated (unused; required by the WTForms
        validator signature).
    username : wtforms.Field
        Field whose ``data`` holds the submitted username.

    Raises
    ------
    ValidationError
        If the whole value doesn't match the mxid pattern, if the localpart
        fails any configured ``validation_regex`` pattern, or if it matches
        any configured ``invalidation_regex`` pattern. Only the first
        failure is reported.
    """
    # Raw f-string: the backslashes are regex escapes, not string escapes
    # (a plain f-string emits "invalid escape sequence" warnings for them).
    re_mxid = rf"^(?P<at>@)?(?P<username>[a-zA-Z_\-=\.\/0-9]+)(?P<server_name>:{re.escape(config.config.server_name)})?$"
    # fullmatch: with re.search, "$" would also accept a trailing "\n".
    # A missing field may arrive as None; treat it as an empty string.
    match = re.fullmatch(re_mxid, username.data or "")
    if not match:
        raise validators.ValidationError(
            f"Username doesn't follow mxid pattern: /{re_mxid}/"
        )
    localpart = match.group("username")
    for pattern in config.config.username["validation_regex"]:
        if not re.search(pattern, localpart):
            raise validators.ValidationError(
                f"Username does not follow custom pattern /{pattern}/"
            )
    for pattern in config.config.username["invalidation_regex"]:
        if re.search(pattern, localpart):
            raise validators.ValidationError(
                f"Username must not follow custom pattern /{pattern}/"
            )
def validate_password(form, password):
    """
    WTForms validator: ensure the password length is within bounds.

    Parameters
    ----------
    form : wtforms.Form
        The form being validated (unused; required by the WTForms
        validator signature).
    password : wtforms.Field
        Field whose ``data`` holds the submitted password.

    Raises
    ------
    ValidationError
        If the password is shorter than ``config.password["min_length"]``
        or longer than 255 characters.
    """
    min_length = config.config.password["min_length"]
    max_length = 255
    # This validator runs before DataRequired, so a missing field (data is
    # None) must fail the length check here instead of raising TypeError.
    length = len(password.data or "")
    if not min_length <= length <= max_length:
        raise validators.ValidationError(
            "Password should be between %s and 255 chars long" % min_length
        )
class RegistrationForm(Form):
    """
    Form describing an account registration request.

    Fields (validators run in the listed order):
    - username: 1-200 chars, then ``validate_username``
    - password: ``validate_password``, must be non-empty, must equal ``confirm``
    - confirm: the repeated password
    - token: one or more capitalized words (e.g. 'DoubleWizardSki'),
      then ``validate_token``
    """

    username = StringField(
        "Username",
        validators=[validators.Length(min=1, max=200), validate_username],
    )
    password = PasswordField(
        "New Password",
        validators=[
            validate_password,
            validators.DataRequired(),
            validators.EqualTo("confirm", message="Passwords must match"),
        ],
    )
    confirm = PasswordField("Repeat Password")
    token = StringField(
        "Token",
        validators=[validators.Regexp(r"^([A-Z][a-z]+)+$"), validate_token],
    )
def get_request_ips(request):
    """
    Return the chain of client and proxy IP addresses for *request*.

    The result is a nonempty list ordered from the original client to the
    closest hop. It holds every address from the X-Forwarded-For header(s)
    in header order, with comma-separated values split into separate
    entries, followed by ``request.remote_addr``. Each IP vouches only for
    the IP before it. This works best if all proxies conform to the
    X-Forwarded-For header spec, including whatever reverse proxy (such as
    nginx) is directly in front of the app, if any.
    (X-Real-IP and similar are not supported at this time.)
    """
    # A single X-Forwarded-For header usually carries "client, proxy1, ...";
    # split it so each list element is one address.
    forwarded = [
        addr.strip()
        for header in request.headers.getlist("X-Forwarded-For")
        for addr in header.split(",")
        if addr.strip()
    ]
    return forwarded + [request.remote_addr]
@auth.verify_token
def verify_token(token):
    """
    Check the admin API shared secret presented by the client.

    Returns True only when *token* equals the configured
    ``admin_api_shared_secret``. The placeholder "APIAdminPassword" is
    always rejected, as are empty or non-string tokens and secrets, so an
    unset secret can never be matched by an empty token.
    """
    import hmac

    secret = config.config.admin_api_shared_secret
    if not isinstance(token, str) or not isinstance(secret, str):
        return False
    if not token or not secret or token == "APIAdminPassword":
        return False
    # constant-time comparison so response timing doesn't leak the secret
    return hmac.compare_digest(token.encode("utf-8"), secret.encode("utf-8"))
@auth.error_handler
def unauthorized():
    """Build the 401 JSON response sent when the shared secret is rejected."""
    body = dict(errcode="MR_BAD_SECRET", error="wrong shared secret")
    response = make_response(jsonify(body), 401)
    return response
@api.route("/static/replace/images/element-logo.png")
def element_logo():
    """
    Serve the configured client logo in place of the bundled Element logo.

    ``config.client_logo`` may contain a ``{cwd}`` placeholder, which is
    replaced by the current working directory followed by "/". The
    Content-Type is guessed from the logo's file extension and falls back
    to "image/jpeg" when it can't be guessed.
    """
    import mimetypes

    logo_path = config.config.client_logo.replace("{cwd}", f"{os.getcwd()}/")
    # The route ends in .png, but the configured logo may be any image
    # format, so derive the type from the actual file instead of hard-coding it.
    mimetype = mimetypes.guess_type(logo_path)[0] or "image/jpeg"
    return send_file(logo_path, mimetype=mimetype)
@api.route("/register", methods=["GET", "POST"])
def register():
    """
    Main user account registration endpoint.

    Non-POST requests render ``register.html``. The page is localized via
    the ``lang`` query parameter, falling back to the best Accept-Language
    match.

    POST registers an account. Send an application/x-www-form-urlencoded
    body with:
    - username
    - password
    - confirm
    - token
    as described in RegistrationForm. Invalid input yields a 400 JSON
    response whose ``error`` holds the per-field validation errors.
    """
    if request.method != "POST":
        cfg = config.config
        min_pw = cfg.password["min_length"]
        language = request.args.get("lang") or request.accept_languages.best
        translated = get_translations(
            language, {"server_name": cfg.server_name, "pw_length": min_pw}
        )
        return render_template(
            "register.html",
            server_name=cfg.server_name,
            pw_length=min_pw,
            uname_regex=cfg.username["validation_regex"],
            uname_regex_inv=cfg.username["invalidation_regex"],
            client_redirect=cfg.client_redirect,
            base_url=cfg.base_url,
            translations=translated,
        )

    logger.debug("an account registration started...")
    submitted = RegistrationForm(request.form)
    logger.debug("validating request data...")
    if not submitted.validate():
        logger.debug("account creation failed!")
        payload = {"errcode": "MR_BAD_USER_REQUEST", "error": submitted.errors}
        return make_response(jsonify(payload), 400)
    logger.debug("request valid")
    return create_account_from_form(submitted)
def create_account_from_form(form):
    """
    Create the account described by an already-validated RegistrationForm.

    The submitted username may be a bare localpart or a full mxid. The
    sigil and server name are stripped, so only the localpart is passed to
    ``create_account``. On success the registration token is marked as
    used, together with the requester's IP chain when ``ip_logging`` is
    enabled (``False`` otherwise), and the new account's credentials are
    returned as JSON.

    Error handling:
    - homeserver unreachable: logged, aborts with 500
    - homeserver answers 400 (e.g. user id already in use): its JSON error
      body is relayed with status 400
    - any other homeserver error (404, 403, ...): logged, aborts with 500
    """
    # remove sigil and the domain from the username
    username = form.username.data.rsplit(":")[0].split("@")[-1]
    logger.debug("creating account %s...", username)
    # send account creation request to the hs
    try:
        account_data = create_account(
            username,
            form.password.data,
            config.config.server_location,
            config.config.registration_shared_secret,
        )
    except exceptions.ConnectionError:
        logger.error(
            "can not connect to %s", config.config.server_location, exc_info=True
        )
        abort(500)
    except exceptions.HTTPError as e:
        resp = e.response
        status_code = resp.status_code
        if status_code == 400:
            # most likely this should only be triggered if a userid
            # is already in use
            try:
                error = resp.json()
            except ValueError:
                # a non-JSON 400 body must not turn into an unhandled 500
                error = {
                    "errcode": "MR_BAD_USER_REQUEST",
                    "error": "homeserver rejected the registration request",
                }
            return make_response(jsonify(error), 400)
        # The body is only parsed for 400 responses; 404/403 pages are often
        # not JSON and would otherwise raise before the diagnostic is logged.
        if status_code == 404:
            logger.error("no HS found at %s", config.config.server_location)
        elif status_code == 403:
            logger.error("wrong shared registration secret or not enabled")
        else:
            logger.error("failure communicating with HS", exc_info=True)
        abort(500)
    logger.debug("using token %s", form.token.data)
    ips = ", ".join(get_request_ips(request)) if config.config.ip_logging else False
    tokens.tokens.use(form.token.data, ips)
    logger.debug("account creation succeded!")
    return jsonify(
        access_token=account_data["access_token"],
        home_server=account_data["home_server"],
        user_id=account_data["user_id"],
        status="success",
        status_code=200,
    )
def get_token(token):
    """
    Return the token named *token* as JSON (via its ``toDict()``).

    Responds 404 with errcode MR_TOKEN_NOT_FOUND if no such token exists.
    """
    # look the token up once instead of querying the store twice
    found = tokens.tokens.get_token(token)
    if found:
        return jsonify(found.toDict())
    resp = {"errcode": "MR_TOKEN_NOT_FOUND", "error": "token does not exist"}
    return make_response(jsonify(resp), 404)
def get_tokens():
    """Return all tokens, as serialized by the store's ``toList()``, as JSON."""
    serialized = tokens.tokens.toList()
    return jsonify(serialized)
def create_token(data):
    """
    Create a new registration token from a parsed JSON request body.

    Parameters
    ----------
    data : dict or None
        Optional keys:
        - ``expiration_date``: ISO 8601 string accepted by
          ``datetime.fromisoformat``, or null for no expiry
        - ``max_usage``: passed through to ``tokens.new`` (defaults to False)

    Returns
    -------
    The new token as JSON. Returns 400 MR_BAD_USER_REQUEST when the body is
    missing/empty or is not a JSON object, and 400 MR_BAD_DATE_FORMAT when
    ``expiration_date`` can't be parsed.
    """
    if not data:
        resp = {
            "errcode": "MR_BAD_USER_REQUEST",
            "error": "no data was sent",
        }
        return make_response(jsonify(resp), 400)
    # A JSON list/string would otherwise be probed with `in`, silently
    # creating a default token (list) or raising TypeError (string).
    if not isinstance(data, dict):
        resp = {"errcode": "MR_BAD_USER_REQUEST", "error": "malformed request"}
        return make_response(jsonify(resp), 400)
    expiration_date = None
    raw_date = data.get("expiration_date")
    if raw_date is not None:
        try:
            expiration_date = datetime.fromisoformat(raw_date)
        except (TypeError, ValueError):
            # TypeError: a non-string value such as a number
            resp = {
                "errcode": "MR_BAD_DATE_FORMAT",
                "error": "date wasn't in YYYY-MM-DD format",
            }
            return make_response(jsonify(resp), 400)
    max_usage = data.get("max_usage", False)
    token = tokens.tokens.new(expiration_date=expiration_date, max_usage=max_usage)
    return jsonify(token.toDict())
def update_token(token, data):
    """
    Apply a partial update (parsed JSON object *data*) to token *token*.

    The ``ips``, ``active`` and ``name`` properties may not be changed and
    requests touching them are rejected with 400. A missing body or a
    non-object body is rejected with 400. Responds 404 when the store's
    ``update`` reports failure (token not found), otherwise returns the
    updated token as JSON.
    """
    if data is None:
        resp = {"errcode": "MR_BAD_USER_REQUEST", "error": "no data was sent"}
        return make_response(jsonify(resp), 400)
    # `key in data` below would raise TypeError or misbehave on non-dicts
    if not isinstance(data, dict):
        resp = {"errcode": "MR_BAD_USER_REQUEST", "error": "malformed request"}
        return make_response(jsonify(resp), 400)
    if any(key in data for key in ("ips", "active", "name")):
        resp = {
            "errcode": "MR_BAD_USER_REQUEST",
            "error": "you're not allowed to change this property",
        }
        return make_response(jsonify(resp), 400)
    if tokens.tokens.update(token, data):
        return jsonify(tokens.tokens.get_token(token).toDict())
    resp = {"errcode": "MR_TOKEN_NOT_FOUND", "error": "token does not exist"}
    return make_response(jsonify(resp), 404)
def delete_token(token):
    """
    Delete the token named *token*.

    Responds 404 (MR_TOKEN_NOT_FOUND) if it doesn't exist, 200 with
    ``{"success": "true"}`` when ``tokens.delete`` reports success, and
    500 with ``{"success": "false"}`` otherwise.
    """
    if not tokens.tokens.get_token(token):
        resp = {"errcode": "MR_TOKEN_NOT_FOUND", "error": "token does not exist"}
        # make_response, consistent with the other token handlers
        return make_response(jsonify(resp), 404)
    if tokens.tokens.delete(token):
        resp = {"success": "true"}
        return make_response(jsonify(resp), 200)
    resp = {"success": "false"}
    return make_response(jsonify(resp), 500)
@healthcheck.route("/health")
def health():
    """Liveness check: always answers status 200 with the body "OK"."""
    body, status = "OK", 200
    return make_response(body, status)
@api.route("/api/version")
@auth.login_required
def version():
    """
    Return the package version as JSON: ``{"version": ...}``.

    The value is read from the ``__version__ = "..."`` line of the package's
    ``__init__.py``. Requires admin authentication. Aborts with 500, after
    logging the path, if no such line is found.
    """
    init_path = os.path.join(__location__, "__init__.py")
    with open(init_path, "r", encoding="utf-8") as file:
        version_file = file.read()
    version_match = re.search(
        r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M
    )
    if version_match is None:
        logger.error("could not find __version__ in %s", init_path)
        abort(500)
    resp = {"version": version_match.group(1)}
    return make_response(jsonify(resp), 200)
@api.route("/api/token", methods=["GET", "POST"])
@auth.login_required
def token():
    """
    List all tokens (GET) or create one from the JSON body (POST).

    Token state is reloaded first. Any other method (e.g. HEAD) falls
    through to a 400 MR_BAD_USER_REQUEST response.
    """
    tokens.tokens.load()
    method = request.method
    if method == "POST":
        return create_token(request.get_json())
    if method == "GET":
        return get_tokens()
    return make_response(
        jsonify({"errcode": "MR_BAD_USER_REQUEST", "error": "malformed request"}),
        400,
    )
@api.route("/api/token/<token>", methods=["GET", "PATCH", "DELETE"])
@auth.login_required
def token_status(token):
    """
    Read (GET), partially update (PATCH, JSON body) or delete (DELETE) the
    token named *token*.

    Token state is reloaded first. Any other method (e.g. HEAD) falls
    through to a 400 MR_BAD_USER_REQUEST response.
    """
    tokens.tokens.load()
    if request.method == "GET":
        return get_token(token)
    if request.method == "PATCH":
        return update_token(token, request.get_json())
    if request.method == "DELETE":
        return delete_token(token)
    resp = {"errcode": "MR_BAD_USER_REQUEST", "error": "malformed request"}
    return make_response(jsonify(resp), 400)