fix merge conflict

This commit is contained in:
ZerataX 2023-01-11 16:08:18 +00:00
parent 4d9fba0587
commit b73868bc24
No known key found for this signature in database
GPG Key ID: 8333735E784DF9D4
4 changed files with 65 additions and 38 deletions

View File

@ -1,12 +1,10 @@
# Standard library imports...
import logging
from requests import exceptions
import re
from urllib.parse import urlparse
import os
import re
from datetime import datetime
# Third-party imports...
from datetime import datetime
from flask import (
Blueprint,
abort,
@ -17,21 +15,24 @@ from flask import (
send_file,
)
from flask_httpauth import HTTPTokenAuth
from requests import exceptions
from werkzeug.exceptions import BadRequest
from wtforms import Form, StringField, PasswordField, validators
# Local imports...
from .matrix_api import create_account
from . import config
from . import tokens
from .constants import __location__
from .limiter import limiter, get_default_rate_limit
from .matrix_api import create_account
from .translation import get_translations
auth = HTTPTokenAuth(scheme="SharedSecret")
logger = logging.getLogger(__name__)
api = Blueprint("api", __name__)
healthcheck = Blueprint("healthcheck", __name__)
limiter.limit(get_default_rate_limit)(api)
def validate_token(form, token):
@ -70,7 +71,7 @@ def validate_username(form, username):
ValidationError
Username doesn't follow mxid requirements
"""
re_mxid = f"^(?P<at>@)?(?P<username>[a-zA-Z_\-=\.\/0-9]+)(?P<server_name>:{ re.escape(config.config.server_name) })?$"
re_mxid = f"^(?P<at>@)?(?P<username>[a-zA-Z_\-=\.\/0-9]+)(?P<server_name>:{re.escape(config.config.server_name)})?$"
match = re.search(re_mxid, username.data)
if not match:
raise validators.ValidationError(
@ -156,7 +157,7 @@ def get_request_ips(request):
@auth.verify_token
def verify_token(token):
return (
token != "APIAdminPassword" and token == config.config.admin_api_shared_secret
token != "APIAdminPassword" and token == config.config.admin_api_shared_secret
)
@ -329,7 +330,7 @@ def delete_token(token):
return make_response(jsonify(resp), 500)
@api.route("/health")
@healthcheck.route("/health")
def health():
return make_response("OK", 200)

View File

@ -1,19 +1,18 @@
import json
import logging
import logging.config
import click
import json
import os
import click
from flask import Flask
from flask.cli import FlaskGroup, pass_script_info
from flask_limiter import Limiter
from flask import request
from flask_cors import CORS
from waitress import serve
from . import config
from . import tokens
from .limiter import limiter
from .tokens import db
import os
def create_app(testing=False):
@ -21,10 +20,12 @@ def create_app(testing=False):
app.testing = testing
with app.app_context():
from .api import api
from .api import api, healthcheck
app.register_blueprint(api)
app.register_blueprint(healthcheck)
limiter.init_app(app)
return app
@ -51,16 +52,10 @@ def cli(info, config_path):
tokens.tokens = tokens.Tokens()
def get_real_user_ip() -> str:
"""ratelimit the users original ip instead of (optional) reverse proxy"""
return next(iter(request.headers.getlist("X-Forwarded-For")), request.remote_addr)
@cli.command("serve", help="start api server")
@pass_script_info
def run_server(info):
app = info.load_app()
Limiter(app, key_func=get_real_user_ip, default_limits=config.config.rate_limit)
if config.config.allow_cors:
CORS(app)
serve(

View File

@ -0,0 +1,17 @@
from flask import request
from flask_limiter import Limiter
from . import config
def get_real_user_ip() -> str:
"""ratelimit the users original ip instead of (optional) reverse proxy"""
return next(iter(request.headers.getlist('X-Forwarded-For')), request.remote_addr)
def get_default_rate_limit() -> str:
"""return limit_string"""
return '; '.join(config.config.rate_limit)
limiter = Limiter(key_func=get_real_user_ip)

View File

@ -2,26 +2,22 @@
# Standard library imports...
import hashlib
import hmac
import logging
import logging.config
import json
import logging.config
import os
import yaml
import random
import re
import requests
from requests import exceptions
import string
import sys
import unittest
from datetime import datetime
from unittest.mock import patch
from urllib.parse import urlparse
# Third-party imports...
import yaml
from parameterized import parameterized
from datetime import datetime
from click.testing import CliRunner
from flask import Flask
from requests import exceptions
# Local imports...
try:
@ -29,7 +25,6 @@ try:
except ModuleNotFoundError:
from context import matrix_registration
from matrix_registration.config import Config
from matrix_registration.app import create_app
from matrix_registration.tokens import db
from matrix_registration.app import (
create_app,
@ -38,7 +33,6 @@ from matrix_registration.app import (
logger = logging.getLogger(__name__)
LOGGING = {
"version": 1,
"root": {"level": "NOTSET", "handlers": ["console"]},
@ -66,7 +60,7 @@ GOOD_CONFIG = {
"db": "sqlite:///%s/tests/db.sqlite" % (os.getcwd(),),
"host": "",
"port": 5000,
"rate_limit": ["100 per day", "10 per minute"],
"rate_limit": ["1000 per day", "100 per minute"],
"allow_cors": False,
"password": {"min_length": 8},
"username": {
@ -390,11 +384,11 @@ class TokensTest(unittest.TestCase):
test_token5 = test_tokens.new()
expected_answer = (
"%s, " % test_token1.name
+ "%s, " % test_token2.name
+ "%s, " % test_token3.name
+ "%s, " % test_token4.name
+ "%s" % test_token5.name
"%s, " % test_token1.name
+ "%s, " % test_token2.name
+ "%s, " % test_token3.name
+ "%s, " % test_token4.name
+ "%s" % test_token5.name
)
self.assertEqual(str(test_tokens), expected_answer)
@ -448,7 +442,7 @@ class ApiTest(unittest.TestCase):
"matrix_registration.matrix_api.requests.post", side_effect=mocked_requests_post
)
def test_register(
self, username, password, confirm, token, status, mock_get, mock_nonce
self, username, password, confirm, token, status, mock_get, mock_nonce
):
matrix_registration.config.config = Config(data=GOOD_CONFIG)
with self.app.app_context():
@ -837,6 +831,26 @@ class ApiTest(unittest.TestCase):
self.assertEqual(token_data["errcode"], "MR_BAD_SECRET")
self.assertEqual(token_data["error"], "wrong shared secret")
def test_rate_limit_exempt(self):
matrix_registration.config.config = Config(GOOD_CONFIG)
with self.app.app_context():
matrix_registration.tokens.tokens = matrix_registration.tokens.Tokens()
secret = matrix_registration.config.config.admin_api_shared_secret
headers = {"Authorization": "SharedSecret %s" % secret}
for i in range(110):
self.client.get("/api/token", headers=headers)
rv = self.client.get("/api/token", headers=headers)
self.assertEqual(rv.status_code, 429)
for i in range(110):
self.client.get("/health")
rv = self.client.get("/health")
self.assertEqual(rv.status_code, 200)
class ConfigTest(unittest.TestCase):
def test_config_update(self):