matrix-registration/tests/test_registration.py

962 lines
36 KiB
Python

# -*- coding: utf-8 -*-
# Standard library imports...
import hashlib
import hmac
import json
import logging.config
import os
import random
import re
import string
import sys
import unittest
from datetime import datetime
from unittest.mock import patch
from urllib.parse import urlparse
# Third-party imports...
import yaml
from parameterized import parameterized
from requests import exceptions
# Local imports...
try:
from .context import matrix_registration
except ModuleNotFoundError:
from context import matrix_registration
from matrix_registration.config import Config
from matrix_registration.tokens import db
from matrix_registration.app import (
create_app,
cli,
)
logger = logging.getLogger(__name__)

# Logging configuration: passed to logging.config.dictConfig() below and
# also embedded in GOOD_CONFIG under the "logging" key.
LOGGING = {
    "version": 1,
    "root": {"level": "NOTSET", "handlers": ["console"]},
    "formatters": {
        "precise": {"format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s"}
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "level": "NOTSET",
            "formatter": "precise",
            "stream": "ext://sys.stdout",
        }
    },
}

# Baseline configuration used by the tests; the BAD_CONFIG* variants below
# each override exactly one key of it.
GOOD_CONFIG = {
    "server_location": "https://matrix.org",
    "server_name": "matrix.org",
    "registration_shared_secret": "coolsharesecret",
    "admin_api_shared_secret": "coolpassword",
    "base_url": "/element",
    "client_redirect": "",
    "client_logo": "",
    "db": "sqlite:///%s/tests/db.sqlite" % (os.getcwd(),),
    "host": "",
    "port": 5000,
    "rate_limit": ["1000 per day", "100 per minute"],
    "allow_cors": False,
    "password": {"min_length": 8},
    "username": {
        # raw string: "\d" is not a valid escape in a normal string literal
        "validation_regex": [r"[a-z\d]"],
        "invalidation_regex": [".*?(admin|support).*?"],
    },
    "ip_logging": False,
    "logging": LOGGING,
}

# wrong matrix server location -> 500
BAD_CONFIG1 = {**GOOD_CONFIG, "server_location": "https://wronghs.org"}

# wrong admin secret password -> 401
BAD_CONFIG2 = {**GOOD_CONFIG, "admin_api_shared_secret": "wrongpassword"}

# wrong matrix shared password -> 500
BAD_CONFIG3 = {**GOOD_CONFIG, "registration_shared_secret": "wrongsecret"}

# Shared state of the mocks below: localparts registered so far and nonces
# handed out so far.
usernames = []
nonces = []

logging.config.dictConfig(LOGGING)
def mock_new_user(username):
    """Build a fake account payload for *username* and remember its localpart.

    The localpart (the text between an optional leading "@" and the first
    ":") is appended to the module-level ``usernames`` list.  The returned
    dict has the keys ``access_token``, ``device_id``, ``home_server`` and
    ``user_id``; token and device id are random.
    """
    token_alphabet = string.ascii_lowercase + string.digits
    access_token = "".join(random.choices(token_alphabet, k=256))
    device_id = "".join(random.choices(string.ascii_uppercase, k=8))

    # NOTE(review): the configured server_location (a URL) is used as the
    # home server here, not server_name - confirm this is intended.
    home_server = matrix_registration.config.config.server_location

    localpart = username.rsplit(":")[0].split("@")[-1]
    usernames.append(localpart)

    return {
        "access_token": access_token,
        "device_id": device_id,
        "home_server": home_server,
        "user_id": "@{}:{}".format(localpart, home_server),
    }
def mocked__get_nonce(server_location):
    """Return a fresh random nonce and record it in the module-level ``nonces``.

    The nonce is 129 characters drawn from lowercase letters and digits.
    ``server_location`` is accepted for signature compatibility with the
    patched function but is not used.
    """
    alphabet = string.ascii_lowercase + string.digits
    fresh_nonce = "".join(random.choices(alphabet, k=129))
    nonces.append(fresh_nonce)
    return fresh_nonce
def mocked_requests_post(*args, **kwargs):
    """Stand-in for ``requests.post`` imitating the admin register endpoint.

    ``args[0]`` is the request URL and ``kwargs["json"]`` the request body
    (keys read: ``nonce``, ``username``, ``password``, ``admin``, ``mac``).

    Responses, in the order the checks run:
      * 400 ``M_UNKNOWN`` when the nonce was not issued by ``mocked__get_nonce``
      * 400 ``M_INVALID_USERNAME`` when the username fails the mxid pattern
      * 400 ``M_USER_IN_USE`` when the localpart is already in ``usernames``
      * 403 ``M_UNKNOWN`` when the supplied HMAC does not match
      * 200 with the ``mock_new_user`` payload otherwise
      * 404 with no body for any other URL or a call without keyword args

    NOTE(review): the nesting of the checks was reconstructed from a source
    whose indentation had been stripped - verify against the original file.
    """

    class MockResponse:
        # Minimal response double exposing json() and raise_for_status().
        def __init__(self, json_data, status_code):
            self.json_data = json_data
            self.status_code = status_code

        def json(self):
            return self.json_data

        def raise_for_status(self):
            if self.status_code != 200:
                raise exceptions.HTTPError(response=self)
            return self.status_code

    domain = urlparse(GOOD_CONFIG["server_location"]).hostname
    re_mxid = r"^@?[a-zA-Z_\-=\.\/0-9]+(:" + re.escape(domain) + r")?$"
    location = "_synapse/admin/v1/register"
    register_url = "%s/%s" % (GOOD_CONFIG["server_location"], location)

    if args[0] != register_url or not kwargs:
        return MockResponse(None, 404)

    req = kwargs["json"]
    if req["nonce"] not in nonces:
        # error payloads are dicts so that .json() behaves like a real
        # response body (the previous set literals could not be indexed)
        return MockResponse(
            {"errcode": "M_UNKNOWN", "error": "unrecognised nonce"}, 400
        )

    # Recompute the HMAC over the NUL-separated request fields using the
    # shared secret from GOOD_CONFIG.
    mac = hmac.new(
        key=str.encode(GOOD_CONFIG["registration_shared_secret"]),
        digestmod=hashlib.sha1,
    )
    mac.update(req["nonce"].encode())
    mac.update(b"\x00")
    mac.update(req["username"].encode())
    mac.update(b"\x00")
    mac.update(req["password"].encode())
    mac.update(b"\x00")
    mac.update(b"admin" if req["admin"] else b"notadmin")
    expected_mac = mac.hexdigest()

    if not re.search(re_mxid, req["username"]):
        return MockResponse(
            {
                "errcode": "M_INVALID_USERNAME",
                "error": "User ID can only contain "
                + "characters a-z, 0-9, or '=_-./'",
            },
            400,
        )
    if req["username"].rsplit(":")[0].split("@")[-1] in usernames:
        return MockResponse(
            {"errcode": "M_USER_IN_USE", "error": "User ID already taken."}, 400
        )
    if req["mac"] != expected_mac:
        return MockResponse({"errcode": "M_UNKNOWN", "error": "HMAC incorrect"}, 403)
    return MockResponse(mock_new_user(req["username"]), 200)
class TokensTest(unittest.TestCase):
    # Tests for matrix_registration.tokens (Token / Tokens).  Every test gets
    # a freshly created database in setUp, which tearDown deletes again.

    def setUp(self):
        matrix_registration.config.config = Config(data=GOOD_CONFIG)
        flask_app = create_app(testing=True)
        with flask_app.app_context():
            flask_app.config.from_mapping(
                SQLALCHEMY_DATABASE_URI=matrix_registration.config.config.db,
                SQLALCHEMY_TRACK_MODIFICATIONS=False,
            )
            db.init_app(flask_app)
            db.create_all()
        self.app = flask_app

    def tearDown(self):
        # skip the leading "sqlite:///" (10 characters) to get the file path
        db_uri = matrix_registration.config.config.db
        os.remove(db_uri[len("sqlite:///"):])

    def test_random_readable_string(self):
        # a readable string of length n splits into n CamelCase words
        for n in range(10):
            readable = matrix_registration.tokens.random_readable_string(length=n)
            words = re.sub("([a-z])([A-Z])", r"\1 \2", readable).split()
            self.assertEqual(len(words), n)

    def test_tokens_empty(self):
        with self.app.app_context():
            test_tokens = matrix_registration.tokens.Tokens()
            # no token should exist at this point
            self.assertFalse(test_tokens.active(""))
            test_token = test_tokens.new()
            # no empty token should have been created
            self.assertFalse(test_tokens.active(""))

    def test_tokens_disable(self):
        with self.app.app_context():
            test_tokens = matrix_registration.tokens.Tokens()

            # new tokens should be active first, inactive after disabling it
            test_token = test_tokens.new()
            self.assertTrue(test_token.active())
            self.assertTrue(test_token.disable())
            self.assertFalse(test_token.active())

            # same round trip, this time through the Tokens container
            test_token2 = test_tokens.new()
            self.assertTrue(test_tokens.active(test_token2.name))
            self.assertTrue(test_tokens.disable(test_token2.name))
            self.assertFalse(test_tokens.active(test_token2.name))

            # NOTE(review): test_token3 is created and used, but the
            # assertions that follow still look at test_token2 - confirm
            # which token was meant.
            test_token3 = test_tokens.new()
            test_token3.use()
            self.assertFalse(test_tokens.active(test_token2.name))
            self.assertFalse(test_tokens.disable(test_token2.name))
            self.assertFalse(test_tokens.active(test_token2.name))

    def test_tokens_load(self):
        with self.app.app_context():
            test_tokens = matrix_registration.tokens.Tokens()
            test_token = test_tokens.new()
            test_token2 = test_tokens.new()
            test_token3 = test_tokens.new(max_usage=True)
            test_token4 = test_tokens.new(
                expiration_date=datetime.fromisoformat("2111-01-01")
            )
            test_token5 = test_tokens.new(
                expiration_date=datetime.fromisoformat("1999-01-01")
            )

            test_tokens.disable(test_token2.name)
            test_tokens.use(test_token3.name)
            test_tokens.use(test_token4.name)
            test_tokens.load()

            # token1: active, unused, no expiration date
            # token2: inactive, unused, no expiration date
            # token3: used once, one-time, now inactive
            # token4: active, used once, expiration date
            # token5: inactive, expiration date
            loaded = test_tokens.get_token
            self.assertEqual(test_token.name, loaded(test_token.name).name)
            self.assertEqual(test_token2.name, loaded(test_token2.name).name)
            self.assertEqual(test_token2.active(), loaded(test_token2.name).active())
            self.assertEqual(test_token3.used, loaded(test_token3.name).used)
            self.assertEqual(test_token3.active(), loaded(test_token3.name).active())
            self.assertEqual(test_token4.used, loaded(test_token4.name).used)
            self.assertEqual(
                test_token4.expiration_date,
                loaded(test_token4.name).expiration_date,
            )
            self.assertEqual(test_token5.active(), loaded(test_token5.name).active())

    @parameterized.expand(
        [
            [None, False],
            [datetime.fromisoformat("2100-01-12"), False],
            [None, True],
            [datetime.fromisoformat("2100-01-12"), True],
        ]
    )
    def test_tokens_new(self, expiration_date, max_usage):
        with self.app.app_context():
            test_tokens = matrix_registration.tokens.Tokens()
            test_token = test_tokens.new(
                expiration_date=expiration_date, max_usage=max_usage
            )
            self.assertIsNotNone(test_token)

            # the token mirrors whether an expiration date was requested ...
            expect_date = (
                self.assertIsNotNone if expiration_date else self.assertIsNone
            )
            expect_date(test_token.expiration_date)

            # ... and whether a usage limit was requested
            expect_limit = self.assertTrue if max_usage else self.assertFalse
            expect_limit(test_token.max_usage)

            self.assertTrue(test_tokens.active(test_token.name))

    @parameterized.expand(
        [
            [None, False, 10, True],
            [datetime.fromisoformat("2100-01-12"), False, 10, True],
            [None, True, 1, False],
            [None, True, 0, True],
            [datetime.fromisoformat("2100-01-12"), True, 1, False],
            [datetime.fromisoformat("2100-01-12"), True, 2, False],
            [datetime.fromisoformat("2100-01-12"), True, 0, True],
        ]
    )
    def test_tokens_active_form(self, expiration_date, max_usage, times_used, active):
        with self.app.app_context():
            test_tokens = matrix_registration.tokens.Tokens()
            test_token = test_tokens.new(
                expiration_date=expiration_date, max_usage=max_usage
            )
            for _ in range(times_used):
                test_tokens.use(test_token.name)

            # expected use counter: every use counts without a limit,
            # otherwise at most one use is recorded
            if not max_usage:
                expected_used = times_used
            elif times_used == 0:
                expected_used = 0
            else:
                expected_used = 1
            self.assertEqual(test_token.used, expected_used)
            self.assertEqual(test_tokens.active(test_token.name), active)

    @parameterized.expand(
        [
            [None, True],
            [datetime.fromisoformat("2100-01-12"), False],
            [datetime.fromisoformat("2200-01-13"), True],
        ]
    )
    def test_tokens_active(self, expiration_date, active):
        with self.app.app_context():
            test_tokens = matrix_registration.tokens.Tokens()
            test_token = test_tokens.new(expiration_date=expiration_date)
            self.assertEqual(test_tokens.active(test_token.name), True)

            # date changed to after expiration date
            with patch("matrix_registration.tokens.datetime") as mock_date:
                mock_date.now.return_value = datetime.fromisoformat("2200-01-12")
                self.assertEqual(test_tokens.active(test_token.name), active)

    @parameterized.expand(
        [
            ["DoubleWizardSky"],
            ["null"],
            ["false"],
        ]
    )
    def test_tokens_repr(self, name):
        # str() of a single Token is its name
        with self.app.app_context():
            test_token1 = matrix_registration.tokens.Token(name=name)
            self.assertEqual(str(test_token1), name)

    def test_token_repr(self):
        # str() of the Tokens container lists the names, comma separated,
        # in creation order
        with self.app.app_context():
            test_tokens = matrix_registration.tokens.Tokens()
            created = [test_tokens.new() for _ in range(5)]
            expected_answer = ", ".join("%s" % token.name for token in created)
            self.assertEqual(str(test_tokens), expected_answer)
class ApiTest(unittest.TestCase):
    # HTTP-level tests for /register, /api/token and /health, driven through
    # the Flask test client.  The database is created in setUp and deleted in
    # tearDown.

    @staticmethod
    def _auth_headers():
        # Authorization header carrying the admin secret of whichever config
        # is active at the moment of the call.
        secret = matrix_registration.config.config.admin_api_shared_secret
        return {"Authorization": "SharedSecret %s" % secret}

    @staticmethod
    def _json_lenient(rv):
        # Parse a response body after swapping single for double quotes.
        return json.loads(rv.data.decode("utf8").replace("'", '"'))

    @staticmethod
    def _json(rv):
        # Parse a response body as-is.
        return json.loads(rv.data.decode("utf8"))

    @staticmethod
    def _reset_tokens():
        # Install a fresh Tokens container on the tokens module and return it.
        matrix_registration.tokens.tokens = matrix_registration.tokens.Tokens()
        return matrix_registration.tokens.tokens

    def setUp(self):
        matrix_registration.config.config = Config(data=GOOD_CONFIG)
        flask_app = create_app(testing=True)
        with flask_app.app_context():
            flask_app.config.from_mapping(
                SQLALCHEMY_DATABASE_URI=matrix_registration.config.config.db,
                SQLALCHEMY_TRACK_MODIFICATIONS=False,
            )
            db.init_app(flask_app)
            db.create_all()
            self.client = flask_app.test_client()
        self.app = flask_app

    def tearDown(self):
        # skip the leading "sqlite:///" (10 characters) to get the file path
        db_uri = matrix_registration.config.config.db
        os.remove(db_uri[len("sqlite:///"):])

    @parameterized.expand(
        [
            ["test1", "test1234", "test1234", True, 200],
            ["", "test1234", "test1234", True, 400],
            ["test2", "", "test1234", True, 400],
            ["test3", "test1234", "", True, 400],
            ["test4", "test1234", "test1234", False, 400],
            ["@test5:matrix.org", "test1234", "test1234", True, 200],
            ["@test6:wronghs.org", "test1234", "test1234", True, 400],
            ["test7", "test1234", "tet1234", True, 400],
            ["teüst8", "test1234", "test1234", True, 400],
            ["@test9@matrix.org", "test1234", "test1234", True, 400],
            ["test11@matrix.org", "test1234", "test1234", True, 400],
            ["", "test1234", "test1234", True, 400],
            [
                "".join(random.choices(string.ascii_uppercase, k=256)),
                "test1234",
                "test1234",
                True,
                400,
            ],
            ["@admin:matrix.org", "test1234", "test1234", True, 400],
            ["matrixadmin123", "test1234", "test1234", True, 400],
        ]
    )
    # check form validators
    @patch("matrix_registration.matrix_api._get_nonce", side_effect=mocked__get_nonce)
    @patch(
        "matrix_registration.matrix_api.requests.post", side_effect=mocked_requests_post
    )
    def test_register(
        self, username, password, confirm, token, status, mock_get, mock_nonce
    ):
        matrix_registration.config.config = Config(data=GOOD_CONFIG)
        with self.app.app_context():
            test_token = self._reset_tokens().new(expiration_date=None, max_usage=True)

            # replace the placeholder domain with the configured homeserver
            domain = urlparse(
                matrix_registration.config.config.server_location
            ).hostname
            if username:
                username = username.replace("matrix.org", domain)

            # "no token" cases submit an empty token name
            if not token:
                test_token.name = ""

            form = {
                "username": username,
                "password": password,
                "confirm": confirm,
                "token": test_token.name,
            }
            rv = self.client.post("/register", data=form)
            if rv.status_code == 200:
                # a successful registration must carry a parseable body
                self._json_lenient(rv)
            self.assertEqual(rv.status_code, status)

    @patch("matrix_registration.matrix_api._get_nonce", side_effect=mocked__get_nonce)
    @patch(
        "matrix_registration.matrix_api.requests.post", side_effect=mocked_requests_post
    )
    def test_register_wrong_hs(self, mock_get, mock_nonce):
        # wrong homeserver location in the config -> 500
        matrix_registration.config.config = Config(data=BAD_CONFIG1)
        with self.app.app_context():
            test_token = self._reset_tokens().new(expiration_date=None, max_usage=True)
            form = {
                "username": "username",
                "password": "password",
                "confirm": "password",
                "token": test_token.name,
            }
            rv = self.client.post("/register", data=form)
            self.assertEqual(rv.status_code, 500)

    @patch("matrix_registration.matrix_api._get_nonce", side_effect=mocked__get_nonce)
    @patch(
        "matrix_registration.matrix_api.requests.post", side_effect=mocked_requests_post
    )
    def test_register_wrong_secret(self, mock_get, mock_nonce):
        # wrong registration shared secret in the config -> 500
        matrix_registration.config.config = Config(data=BAD_CONFIG3)
        with self.app.app_context():
            test_token = self._reset_tokens().new(expiration_date=None, max_usage=True)
            form = {
                "username": "username",
                "password": "password",
                "confirm": "password",
                "token": test_token.name,
            }
            rv = self.client.post("/register", data=form)
            self.assertEqual(rv.status_code, 500)

    def test_get_tokens(self):
        matrix_registration.config.config = Config(data=GOOD_CONFIG)
        with self.app.app_context():
            test_token = self._reset_tokens().new(expiration_date=None, max_usage=True)

            rv = self.client.get("/api/token", headers=self._auth_headers())
            self.assertEqual(rv.status_code, 200)

            token_data = self._json_lenient(rv)
            self.assertEqual(token_data[0]["expiration_date"], None)
            self.assertEqual(token_data[0]["max_usage"], True)

    def test_error_get_tokens(self):
        matrix_registration.config.config = Config(data=BAD_CONFIG2)
        with self.app.app_context():
            test_token = self._reset_tokens().new(expiration_date=None, max_usage=True)

            # take the secret from the bad config, then serve with the good one
            headers = self._auth_headers()
            matrix_registration.config.config = Config(data=GOOD_CONFIG)

            rv = self.client.get("/api/token", headers=headers)
            self.assertEqual(rv.status_code, 401)

            token_data = self._json_lenient(rv)
            self.assertEqual(token_data["errcode"], "MR_BAD_SECRET")
            self.assertEqual(token_data["error"], "wrong shared secret")

    @parameterized.expand(
        [
            [None, True, None],
            ["2020-12-24", False, "2020-12-24 00:00:00"],
            ["2200-05-12", True, "2200-05-12 00:00:00"],
        ]
    )
    def test_post_token(self, expiration_date, max_usage, parsed_date):
        matrix_registration.config.config = Config(data=GOOD_CONFIG)
        with self.app.app_context():
            test_token = self._reset_tokens().new(expiration_date=None, max_usage=True)

            payload = {"expiration_date": expiration_date, "max_usage": max_usage}
            rv = self.client.post(
                "/api/token",
                data=json.dumps(payload),
                content_type="application/json",
                headers=self._auth_headers(),
            )
            self.assertEqual(rv.status_code, 200)

            token_data = self._json_lenient(rv)
            self.assertEqual(token_data["expiration_date"], parsed_date)
            self.assertEqual(token_data["max_usage"], max_usage)
            self.assertTrue(token_data["name"] is not None)

    def test_error_post_token(self):
        matrix_registration.config.config = Config(data=BAD_CONFIG2)
        with self.app.app_context():
            test_token = self._reset_tokens().new(expiration_date=None, max_usage=True)

            # wrong secret -> 401
            headers = self._auth_headers()
            matrix_registration.config.config = Config(data=GOOD_CONFIG)
            rv = self.client.post(
                "/api/token",
                data=json.dumps({"expiration_date": "24.12.2020", "max_usage": False}),
                content_type="application/json",
                headers=headers,
            )
            self.assertEqual(rv.status_code, 401)
            token_data = self._json_lenient(rv)
            self.assertEqual(token_data["errcode"], "MR_BAD_SECRET")
            self.assertEqual(token_data["error"], "wrong shared secret")

            # right secret, malformed date -> 400
            headers = self._auth_headers()
            rv = self.client.post(
                "/api/token",
                data=json.dumps({"expiration_date": "2020-24-12", "max_usage": False}),
                content_type="application/json",
                headers=headers,
            )
            self.assertEqual(rv.status_code, 400)
            token_data = self._json(rv)
            self.assertEqual(token_data["errcode"], "MR_BAD_DATE_FORMAT")
            self.assertEqual(token_data["error"], "date wasn't in YYYY-MM-DD format")

    def test_patch_token(self):
        matrix_registration.config.config = Config(data=GOOD_CONFIG)
        with self.app.app_context():
            test_token = self._reset_tokens().new(max_usage=True)

            rv = self.client.patch(
                "/api/token/" + test_token.name,
                data=json.dumps({"disabled": True}),
                content_type="application/json",
                headers=self._auth_headers(),
            )
            self.assertEqual(rv.status_code, 200)

            token_data = self._json_lenient(rv)
            self.assertEqual(token_data["active"], False)
            self.assertEqual(token_data["max_usage"], True)
            self.assertEqual(token_data["name"], test_token.name)

    def test_error_patch_token(self):
        matrix_registration.config.config = Config(data=BAD_CONFIG2)
        with self.app.app_context():
            test_token = self._reset_tokens().new(max_usage=True)

            # wrong secret -> 401
            headers = self._auth_headers()
            matrix_registration.config.config = Config(data=GOOD_CONFIG)
            rv = self.client.patch(
                "/api/token/" + test_token.name,
                data=json.dumps({"disabled": True}),
                content_type="application/json",
                headers=headers,
            )
            self.assertEqual(rv.status_code, 401)
            token_data = self._json_lenient(rv)
            self.assertEqual(token_data["errcode"], "MR_BAD_SECRET")
            self.assertEqual(token_data["error"], "wrong shared secret")

            # right secret, property that may not be changed -> 400
            headers = self._auth_headers()
            rv = self.client.patch(
                "/api/token/" + test_token.name,
                data=json.dumps({"active": False}),
                content_type="application/json",
                headers=headers,
            )
            self.assertEqual(rv.status_code, 400)
            token_data = self._json(rv)
            self.assertEqual(token_data["errcode"], "MR_BAD_USER_REQUEST")
            self.assertEqual(
                token_data["error"], "you're not allowed to change this property"
            )

            # right secret, unknown token -> 404
            rv = self.client.patch(
                "/api/token/" + "nicememe",
                data=json.dumps({"disabled": True}),
                content_type="application/json",
                headers=headers,
            )
            self.assertEqual(rv.status_code, 404)
            token_data = self._json(rv)
            self.assertEqual(token_data["errcode"], "MR_TOKEN_NOT_FOUND")
            self.assertEqual(token_data["error"], "token does not exist")

    def test_delete_token(self):
        matrix_registration.config.config = Config(data=GOOD_CONFIG)
        with self.app.app_context():
            test_token = self._reset_tokens().new(max_usage=True)
            headers = self._auth_headers()
            token_url = "/api/token/" + test_token.name

            # present before the delete, gone afterwards
            rv = self.client.get(
                token_url, content_type="application/json", headers=headers
            )
            self.assertEqual(rv.status_code, 200)

            rv = self.client.delete(
                token_url, content_type="application/json", headers=headers
            )
            self.assertEqual(rv.status_code, 200)

            rv = self.client.get(
                token_url, content_type="application/json", headers=headers
            )
            self.assertEqual(rv.status_code, 404)

    def test_error_delete_token(self):
        matrix_registration.config.config = Config(data=BAD_CONFIG2)
        with self.app.app_context():
            test_token = self._reset_tokens().new(max_usage=True)

            # wrong secret -> 401
            headers = self._auth_headers()
            matrix_registration.config.config = Config(data=GOOD_CONFIG)
            rv = self.client.delete(
                "/api/token/" + test_token.name,
                content_type="application/json",
                headers=headers,
            )
            self.assertEqual(rv.status_code, 401)
            token_data = self._json_lenient(rv)
            self.assertEqual(token_data["errcode"], "MR_BAD_SECRET")
            self.assertEqual(token_data["error"], "wrong shared secret")

            # right secret, unknown token -> 404
            headers = self._auth_headers()
            rv = self.client.delete(
                "/api/token/" + "nicememe",
                content_type="application/json",
                headers=headers,
            )
            self.assertEqual(rv.status_code, 404)
            token_data = self._json(rv)
            self.assertEqual(token_data["errcode"], "MR_TOKEN_NOT_FOUND")
            self.assertEqual(token_data["error"], "token does not exist")

    @parameterized.expand(
        [
            [None, True, None],
            [datetime.fromisoformat("2020-12-24"), False, "2020-12-24 00:00:00"],
            [datetime.fromisoformat("2200-05-12"), True, "2200-05-12 00:00:00"],
        ]
    )
    def test_get_token(self, expiration_date, max_usage, parsed_date):
        # BAD_CONFIG2 is active here, but the header is built from that same
        # config, so the secret sent matches the one configured.
        matrix_registration.config.config = Config(data=BAD_CONFIG2)
        with self.app.app_context():
            test_token = self._reset_tokens().new(
                expiration_date=expiration_date, max_usage=max_usage
            )

            rv = self.client.get(
                "/api/token/" + test_token.name,
                content_type="application/json",
                headers=self._auth_headers(),
            )
            self.assertEqual(rv.status_code, 200)

            token_data = self._json(rv)
            self.assertEqual(token_data["expiration_date"], parsed_date)
            self.assertEqual(token_data["max_usage"], max_usage)

    def test_error_get_token(self):
        matrix_registration.config.config = Config(data=BAD_CONFIG2)
        with self.app.app_context():
            test_token = self._reset_tokens().new(max_usage=True)

            # matching secret, unknown token -> 404
            rv = self.client.get(
                "/api/token/" + "nice_meme",
                content_type="application/json",
                headers=self._auth_headers(),
            )
            self.assertEqual(rv.status_code, 404)
            token_data = self._json(rv)
            self.assertEqual(token_data["errcode"], "MR_TOKEN_NOT_FOUND")
            self.assertEqual(token_data["error"], "token does not exist")

            # wrong secret -> 401
            # NOTE(review): this part issues a PATCH although the test is
            # about GET - confirm which verb was intended.
            matrix_registration.config.config = Config(data=BAD_CONFIG2)
            headers = self._auth_headers()
            matrix_registration.config.config = Config(data=GOOD_CONFIG)
            rv = self.client.patch(
                "/api/token/" + test_token.name,
                data=json.dumps({"disabled": True}),
                content_type="application/json",
                headers=headers,
            )
            self.assertEqual(rv.status_code, 401)
            token_data = self._json_lenient(rv)
            self.assertEqual(token_data["errcode"], "MR_BAD_SECRET")
            self.assertEqual(token_data["error"], "wrong shared secret")

    def test_rate_limit_exempt(self):
        matrix_registration.config.config = Config(data=GOOD_CONFIG)
        with self.app.app_context():
            self._reset_tokens()
            headers = self._auth_headers()

            # /api/token: the request after 110 calls is rejected with 429
            for _ in range(110):
                self.client.get("/api/token", headers=headers)
            rv = self.client.get("/api/token", headers=headers)
            self.assertEqual(rv.status_code, 429)

            # /health: still answers 200 after the same number of calls
            for _ in range(110):
                self.client.get("/health")
            rv = self.client.get("/health")
            self.assertEqual(rv.status_code, 200)
class ConfigTest(unittest.TestCase):
    # Tests for building a Config from a dict or a YAML file and for
    # updating it in place.

    def test_config_update(self):
        # values come from GOOD_CONFIG first ...
        matrix_registration.config.config = Config(data=GOOD_CONFIG)
        self.assertEqual(matrix_registration.config.config.port, GOOD_CONFIG["port"])
        self.assertEqual(
            matrix_registration.config.config.server_location,
            GOOD_CONFIG["server_location"],
        )

        # ... and from BAD_CONFIG1 once update() has been called with it
        matrix_registration.config.config.update(BAD_CONFIG1)
        self.assertEqual(matrix_registration.config.config.port, BAD_CONFIG1["port"])
        self.assertEqual(
            matrix_registration.config.config.server_location,
            BAD_CONFIG1["server_location"],
        )

    def test_config_path(self):
        good_config_path = "tests/test_config.yaml"
        with open(good_config_path, "w") as outfile:
            yaml.dump(GOOD_CONFIG, outfile, default_flow_style=False)
        # registered as a cleanup so the file is removed even when loading
        # the config or the assertion below fails
        self.addCleanup(os.remove, good_config_path)

        matrix_registration.config.config = Config(path=good_config_path)
        self.assertIsNotNone(matrix_registration.config.config)
class CliTest(unittest.TestCase):
    # Drives the command line interface (generate / status) through the Flask
    # CLI test runner, using a config file written for each test.
    path = "tests/test_config.yaml"
    db = "tests/db.sqlite"

    def setUp(self):
        # a database left behind by an earlier run is removed first
        try:
            os.remove(self.db)
        except FileNotFoundError:
            pass
        with open(self.path, "w") as outfile:
            yaml.dump(GOOD_CONFIG, outfile, default_flow_style=False)

    def tearDown(self):
        os.remove(self.path)
        os.remove(self.db)

    def _status(self, runner, name):
        # Run "status -s <name>" and split its output into the first line
        # (the validity message) and the remainder (the token info as text).
        status = runner.invoke(cli, ["--config-path", self.path, "status", "-s", name])
        valid, info_dict_string = status.output.strip().split("\n", 1)
        return valid, info_dict_string

    def test_create_token(self):
        runner = create_app().test_cli_runner()

        # a token generated with "-m 1" is reported as valid
        generate = runner.invoke(cli, ["--config-path", self.path, "generate", "-m", 1])
        name1 = generate.output.strip()
        valid, info_dict_string = self._status(runner, name1)
        self.assertEqual(valid, "This token is valid")
        comparison_dict = {
            "name": name1,
            "used": 0,
            "expiration_date": None,
            "max_usage": 1,
            "disabled": False,
            "ips": [],
            "active": True,
        }
        self.assertEqual(json.loads(info_dict_string), comparison_dict)

        # after "status -d" the same token is reported as disabled
        runner.invoke(cli, ["--config-path", self.path, "status", "-d", name1])
        valid, info_dict_string = self._status(runner, name1)
        self.assertEqual(valid, "This token is not valid")
        comparison_dict = {
            "name": name1,
            "used": 0,
            "expiration_date": None,
            "max_usage": 1,
            "disabled": True,
            "ips": [],
            "active": False,
        }
        self.assertEqual(json.loads(info_dict_string), comparison_dict)

        # a token generated with an expiration date in the future
        generate = runner.invoke(
            cli, ["--config-path", self.path, "generate", "-e", "2220-05-12"]
        )
        name2 = generate.output.strip()
        valid, info_dict_string = self._status(runner, name2)
        self.assertEqual(valid, "This token is valid")
        comparison_dict = {
            "name": name2,
            "used": 0,
            "expiration_date": "2220-05-12 00:00:00",
            "max_usage": 0,
            "disabled": False,
            "ips": [],
            "active": True,
        }
        self.assertEqual(json.loads(info_dict_string), comparison_dict)

        # "status -l" lists both tokens, comma separated, in creation order
        status = runner.invoke(cli, ["--config-path", self.path, "status", "-l"])
        token_listing = status.output.strip()
        self.assertEqual(token_listing, f"{name1}, {name2}")
# Passing "logging" on the command line switches on DEBUG-level logging.
if "logging" in sys.argv:
    logging.basicConfig(level=logging.DEBUG)

# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()