fix merge conflict
This commit is contained in:
parent
4d9fba0587
commit
b73868bc24
|
@ -1,12 +1,10 @@
|
|||
# Standard library imports...
|
||||
import logging
|
||||
from requests import exceptions
|
||||
import re
|
||||
from urllib.parse import urlparse
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
|
||||
# Third-party imports...
|
||||
from datetime import datetime
|
||||
from flask import (
|
||||
Blueprint,
|
||||
abort,
|
||||
|
@ -17,21 +15,24 @@ from flask import (
|
|||
send_file,
|
||||
)
|
||||
from flask_httpauth import HTTPTokenAuth
|
||||
from requests import exceptions
|
||||
from werkzeug.exceptions import BadRequest
|
||||
from wtforms import Form, StringField, PasswordField, validators
|
||||
|
||||
# Local imports...
|
||||
from .matrix_api import create_account
|
||||
from . import config
|
||||
from . import tokens
|
||||
from .constants import __location__
|
||||
from .limiter import limiter, get_default_rate_limit
|
||||
from .matrix_api import create_account
|
||||
from .translation import get_translations
|
||||
|
||||
|
||||
auth = HTTPTokenAuth(scheme="SharedSecret")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
api = Blueprint("api", __name__)
|
||||
healthcheck = Blueprint("healthcheck", __name__)
|
||||
limiter.limit(get_default_rate_limit)(api)
|
||||
|
||||
|
||||
def validate_token(form, token):
|
||||
|
@ -70,7 +71,7 @@ def validate_username(form, username):
|
|||
ValidationError
|
||||
Username doesn't follow mxid requirements
|
||||
"""
|
||||
re_mxid = f"^(?P<at>@)?(?P<username>[a-zA-Z_\-=\.\/0-9]+)(?P<server_name>:{ re.escape(config.config.server_name) })?$"
|
||||
re_mxid = f"^(?P<at>@)?(?P<username>[a-zA-Z_\-=\.\/0-9]+)(?P<server_name>:{re.escape(config.config.server_name)})?$"
|
||||
match = re.search(re_mxid, username.data)
|
||||
if not match:
|
||||
raise validators.ValidationError(
|
||||
|
@ -156,7 +157,7 @@ def get_request_ips(request):
|
|||
@auth.verify_token
|
||||
def verify_token(token):
|
||||
return (
|
||||
token != "APIAdminPassword" and token == config.config.admin_api_shared_secret
|
||||
token != "APIAdminPassword" and token == config.config.admin_api_shared_secret
|
||||
)
|
||||
|
||||
|
||||
|
@ -329,7 +330,7 @@ def delete_token(token):
|
|||
return make_response(jsonify(resp), 500)
|
||||
|
||||
|
||||
@api.route("/health")
|
||||
@healthcheck.route("/health")
|
||||
def health():
|
||||
return make_response("OK", 200)
|
||||
|
||||
|
|
|
@ -1,19 +1,18 @@
|
|||
import json
|
||||
import logging
|
||||
import logging.config
|
||||
import click
|
||||
import json
|
||||
import os
|
||||
|
||||
import click
|
||||
from flask import Flask
|
||||
from flask.cli import FlaskGroup, pass_script_info
|
||||
from flask_limiter import Limiter
|
||||
from flask import request
|
||||
from flask_cors import CORS
|
||||
from waitress import serve
|
||||
|
||||
from . import config
|
||||
from . import tokens
|
||||
from .limiter import limiter
|
||||
from .tokens import db
|
||||
import os
|
||||
|
||||
|
||||
def create_app(testing=False):
|
||||
|
@ -21,10 +20,12 @@ def create_app(testing=False):
|
|||
app.testing = testing
|
||||
|
||||
with app.app_context():
|
||||
from .api import api
|
||||
from .api import api, healthcheck
|
||||
|
||||
app.register_blueprint(api)
|
||||
app.register_blueprint(healthcheck)
|
||||
|
||||
limiter.init_app(app)
|
||||
return app
|
||||
|
||||
|
||||
|
@ -51,16 +52,10 @@ def cli(info, config_path):
|
|||
tokens.tokens = tokens.Tokens()
|
||||
|
||||
|
||||
def get_real_user_ip() -> str:
|
||||
"""ratelimit the users original ip instead of (optional) reverse proxy"""
|
||||
return next(iter(request.headers.getlist("X-Forwarded-For")), request.remote_addr)
|
||||
|
||||
|
||||
@cli.command("serve", help="start api server")
|
||||
@pass_script_info
|
||||
def run_server(info):
|
||||
app = info.load_app()
|
||||
Limiter(app, key_func=get_real_user_ip, default_limits=config.config.rate_limit)
|
||||
if config.config.allow_cors:
|
||||
CORS(app)
|
||||
serve(
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
from flask import request
|
||||
from flask_limiter import Limiter
|
||||
|
||||
from . import config
|
||||
|
||||
|
||||
def get_real_user_ip() -> str:
|
||||
"""ratelimit the users original ip instead of (optional) reverse proxy"""
|
||||
return next(iter(request.headers.getlist('X-Forwarded-For')), request.remote_addr)
|
||||
|
||||
|
||||
def get_default_rate_limit() -> str:
|
||||
"""return limit_string"""
|
||||
return '; '.join(config.config.rate_limit)
|
||||
|
||||
|
||||
limiter = Limiter(key_func=get_real_user_ip)
|
|
@ -2,26 +2,22 @@
|
|||
# Standard library imports...
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import logging.config
|
||||
import json
|
||||
import logging.config
|
||||
import os
|
||||
import yaml
|
||||
import random
|
||||
import re
|
||||
import requests
|
||||
from requests import exceptions
|
||||
import string
|
||||
import sys
|
||||
import unittest
|
||||
from datetime import datetime
|
||||
from unittest.mock import patch
|
||||
from urllib.parse import urlparse
|
||||
|
||||
# Third-party imports...
|
||||
import yaml
|
||||
from parameterized import parameterized
|
||||
from datetime import datetime
|
||||
from click.testing import CliRunner
|
||||
from flask import Flask
|
||||
from requests import exceptions
|
||||
|
||||
# Local imports...
|
||||
try:
|
||||
|
@ -29,7 +25,6 @@ try:
|
|||
except ModuleNotFoundError:
|
||||
from context import matrix_registration
|
||||
from matrix_registration.config import Config
|
||||
from matrix_registration.app import create_app
|
||||
from matrix_registration.tokens import db
|
||||
from matrix_registration.app import (
|
||||
create_app,
|
||||
|
@ -38,7 +33,6 @@ from matrix_registration.app import (
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
LOGGING = {
|
||||
"version": 1,
|
||||
"root": {"level": "NOTSET", "handlers": ["console"]},
|
||||
|
@ -66,7 +60,7 @@ GOOD_CONFIG = {
|
|||
"db": "sqlite:///%s/tests/db.sqlite" % (os.getcwd(),),
|
||||
"host": "",
|
||||
"port": 5000,
|
||||
"rate_limit": ["100 per day", "10 per minute"],
|
||||
"rate_limit": ["1000 per day", "100 per minute"],
|
||||
"allow_cors": False,
|
||||
"password": {"min_length": 8},
|
||||
"username": {
|
||||
|
@ -390,11 +384,11 @@ class TokensTest(unittest.TestCase):
|
|||
test_token5 = test_tokens.new()
|
||||
|
||||
expected_answer = (
|
||||
"%s, " % test_token1.name
|
||||
+ "%s, " % test_token2.name
|
||||
+ "%s, " % test_token3.name
|
||||
+ "%s, " % test_token4.name
|
||||
+ "%s" % test_token5.name
|
||||
"%s, " % test_token1.name
|
||||
+ "%s, " % test_token2.name
|
||||
+ "%s, " % test_token3.name
|
||||
+ "%s, " % test_token4.name
|
||||
+ "%s" % test_token5.name
|
||||
)
|
||||
|
||||
self.assertEqual(str(test_tokens), expected_answer)
|
||||
|
@ -448,7 +442,7 @@ class ApiTest(unittest.TestCase):
|
|||
"matrix_registration.matrix_api.requests.post", side_effect=mocked_requests_post
|
||||
)
|
||||
def test_register(
|
||||
self, username, password, confirm, token, status, mock_get, mock_nonce
|
||||
self, username, password, confirm, token, status, mock_get, mock_nonce
|
||||
):
|
||||
matrix_registration.config.config = Config(data=GOOD_CONFIG)
|
||||
with self.app.app_context():
|
||||
|
@ -837,6 +831,26 @@ class ApiTest(unittest.TestCase):
|
|||
self.assertEqual(token_data["errcode"], "MR_BAD_SECRET")
|
||||
self.assertEqual(token_data["error"], "wrong shared secret")
|
||||
|
||||
def test_rate_limit_exempt(self):
|
||||
matrix_registration.config.config = Config(GOOD_CONFIG)
|
||||
|
||||
with self.app.app_context():
|
||||
matrix_registration.tokens.tokens = matrix_registration.tokens.Tokens()
|
||||
secret = matrix_registration.config.config.admin_api_shared_secret
|
||||
headers = {"Authorization": "SharedSecret %s" % secret}
|
||||
|
||||
for i in range(110):
|
||||
self.client.get("/api/token", headers=headers)
|
||||
|
||||
rv = self.client.get("/api/token", headers=headers)
|
||||
self.assertEqual(rv.status_code, 429)
|
||||
|
||||
for i in range(110):
|
||||
self.client.get("/health")
|
||||
|
||||
rv = self.client.get("/health")
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
|
||||
|
||||
class ConfigTest(unittest.TestCase):
|
||||
def test_config_update(self):
|
||||
|
|
Loading…
Reference in New Issue