matrix-python-sdk/test/crypto/olm_device_test.py

207 lines
8.4 KiB
Python

import pytest
pytest.importorskip("olm") # noqa
import json
from copy import deepcopy
import responses
from matrix_client.api import MATRIX_V2_API_PATH
from matrix_client.client import MatrixClient
from matrix_client.crypto.olm_device import OlmDevice
from test.response_examples import example_key_upload_response
# Base URL given to MatrixClient and used to build the mocked /keys/upload URLs.
HOSTNAME = 'http://example.com'
class TestOlmDevice:
    """Tests for ``OlmDevice``: JSON signing/verification and key uploads.

    HTTP traffic is intercepted with ``responses``; every upload test
    registers a canned reply for the ``/keys/upload`` endpoint.
    """

    # These objects are created once, at class-definition time, and are shared
    # by every test that goes through ``self.device``. Tests that depend on the
    # one-time-key manager state therefore reset ``server_counts`` themselves.
    cli = MatrixClient(HOSTNAME)
    user_id = '@user:matrix.org'
    device_id = 'QBUAZIFURK'
    device = OlmDevice(cli.api, user_id, device_id)
    # Public ed25519 key matching the signatures produced by ``device``.
    signing_key = device.olm_account.identity_keys['ed25519']

    def test_sign_json(self):
        """Signing adds a signature without altering the rest of the payload."""
        example_payload = {
            "name": "example.org",
            "unsigned": {
                "age_ts": 922834800000
            }
        }
        saved_payload = deepcopy(example_payload)

        signed_payload = self.device.sign_json(example_payload)
        signature = signed_payload.pop('signatures')
        # We should not have modified the payload besides the signatures key
        assert example_payload == saved_payload
        key_id = 'ed25519:' + self.device_id
        assert signature[self.user_id][key_id]

    def test_verify_json(self):
        """A known-good signature verifies; tampered or unsigned payloads do not."""
        example_payload = {
            "test": "test",
            "unsigned": {
                "age_ts": 922834800000
            },
            "signatures": {
                "@user:matrix.org": {
                    "ed25519:QBUAZIFURK": ("WI7TgwqTp4YVn1dFWmDu7xrJvEikEzAbmoqyM5JY5t0P"
                                           "6fVaiMFAirmwb13GzIyYDLR+nQfoksNBcrp7xSaMCA")
                }
            }
        }
        saved_payload = deepcopy(example_payload)
        signing_key = "WQF5z9b4DV1DANI5HUMJfhTIDvJs1jkoGTLY6AQdjF0"

        assert self.device.verify_json(example_payload, signing_key, self.user_id,
                                       self.device_id)

        # We should not have modified the payload
        assert example_payload == saved_payload

        # Try to verify an object that has been tampered with
        example_payload['test'] = 'test1'
        assert not self.device.verify_json(example_payload, signing_key, self.user_id,
                                           self.device_id)

        # Try to verify invalid payloads: first without this user's signature,
        # then without any 'signatures' key at all.
        example_payload['signatures'].pop(self.user_id)
        assert not self.device.verify_json(example_payload, signing_key, self.user_id,
                                           self.device_id)
        example_payload.pop('signatures')
        assert not self.device.verify_json(example_payload, signing_key, self.user_id,
                                           self.device_id)

    def test_sign_verify(self):
        """A payload signed by the device verifies against the device's own key."""
        example_payload = {
            "name": "example.org",
        }

        signed_payload = self.device.sign_json(example_payload)
        assert self.device.verify_json(signed_payload, self.signing_key, self.user_id,
                                       self.device_id)

    @responses.activate
    def test_upload_identity_keys(self):
        """Uploading identity keys sends signed device keys and stores the counts."""
        upload_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/upload'
        self.device.one_time_keys_manager.server_counts = {}
        resp = deepcopy(example_key_upload_response)

        responses.add(responses.POST, upload_url, json=resp)

        assert self.device.upload_identity_keys() is None
        assert self.device.one_time_keys_manager.server_counts == \
            resp['one_time_key_counts']

        req_device_keys = json.loads(responses.calls[0].request.body)['device_keys']
        assert req_device_keys['user_id'] == self.user_id
        assert req_device_keys['device_id'] == self.device_id
        assert req_device_keys['algorithms'] == self.device._algorithms
        assert 'keys' in req_device_keys
        assert 'signatures' in req_device_keys
        assert self.device.verify_json(req_device_keys, self.signing_key, self.user_id,
                                       self.device_id)

    @pytest.mark.parametrize('proportion', [-1, 2])
    def test_upload_identity_keys_invalid(self, proportion):
        """An out-of-range ``signed_keys_proportion`` is rejected with ValueError."""
        with pytest.raises(ValueError):
            OlmDevice(self.cli.api,
                      self.user_id,
                      self.device_id,
                      signed_keys_proportion=proportion)

    @responses.activate
    @pytest.mark.parametrize('proportion', [0, 1, 0.5, 0.33])
    def test_upload_one_time_keys(self, proportion):
        """The signed/unsigned split of uploaded one-time keys follows ``proportion``."""
        upload_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/upload'
        resp = deepcopy(example_key_upload_response)
        counts = resp['one_time_key_counts']
        counts['curve25519'] = counts['signed_curve25519'] = 10
        responses.add(responses.POST, upload_url, json=resp)

        # A fresh device is used so the shared one's key counts are not involved.
        device = OlmDevice(
            self.cli.api, self.user_id, self.device_id, signed_keys_proportion=proportion)
        assert not device.one_time_keys_manager.server_counts

        max_keys = device.olm_account.max_one_time_keys // 2
        signed_keys_to_upload = \
            max(round(max_keys * proportion) - counts['signed_curve25519'], 0)
        unsigned_keys_to_upload = \
            max(round(max_keys * (1 - proportion)) - counts['curve25519'], 0)
        expected_return = {}
        if signed_keys_to_upload:
            expected_return['signed_curve25519'] = signed_keys_to_upload
        if unsigned_keys_to_upload:
            expected_return['curve25519'] = unsigned_keys_to_upload

        assert device.upload_one_time_keys() == expected_return
        # Two requests are expected, the second of which carries the keys.
        # NOTE(review): presumably the first one only fetches the server
        # counts -- confirm against OlmDevice.upload_one_time_keys.
        assert len(responses.calls) == 2
        assert device.one_time_keys_manager.server_counts == resp['one_time_key_counts']

        req_otk = json.loads(responses.calls[1].request.body)['one_time_keys']
        assert len(req_otk) == unsigned_keys_to_upload + signed_keys_to_upload
        assert len([key for key in req_otk if not key.startswith('signed')]) == \
            unsigned_keys_to_upload
        assert len([key for key in req_otk if key.startswith('signed')]) == \
            signed_keys_to_upload

        # Every signed key must carry a valid signature from this device.
        # Entries are selected by prefix, as in the counting assertions above,
        # and the result is asserted so a bad signature fails the test.
        signing_key = device.olm_account.identity_keys['ed25519']
        for key_id, key in req_otk.items():
            if key_id.startswith('signed'):
                assert device.verify_json(key, signing_key, self.user_id,
                                          self.device_id)

    @responses.activate
    def test_upload_one_time_keys_enough(self):
        """Nothing is reported as uploaded when the server already holds enough keys."""
        upload_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/upload'
        self.device.one_time_keys_manager.server_counts = {}
        limit = self.device.olm_account.max_one_time_keys // 2
        resp = {'one_time_key_counts': {'signed_curve25519': limit}}
        responses.add(responses.POST, upload_url, json=resp)

        assert not self.device.upload_one_time_keys()

    @responses.activate
    def test_upload_one_time_keys_force_update(self):
        """``force_update=True`` triggers one more request than a plain upload."""
        upload_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/upload'
        self.device.one_time_keys_manager.server_counts = {'curve25519': 10}
        resp = deepcopy(example_key_upload_response)
        responses.add(responses.POST, upload_url, json=resp)

        # With counts already known, a plain upload issues a single request.
        self.device.upload_one_time_keys()
        assert len(responses.calls) == 1

        # The forced upload adds two requests to the running total.
        self.device.upload_one_time_keys(force_update=True)
        assert len(responses.calls) == 3

    @responses.activate
    @pytest.mark.parametrize('count,should_upload', [(0, True), (25, False), (4, True)])
    def test_update_one_time_key_counts(self, count, should_upload):
        """Keys are uploaded only for the counts flagged ``should_upload``."""
        upload_url = HOSTNAME + MATRIX_V2_API_PATH + '/keys/upload'
        responses.add(responses.POST, upload_url, json={'one_time_key_counts': {}})
        # NOTE: this mutates the shared class-level device and is not restored.
        self.device.one_time_keys_manager.target_counts['signed_curve25519'] = 50
        self.device.one_time_keys_manager.server_counts.clear()

        # A count of 0 is modelled as the key type being absent from the dict.
        count_dict = {}
        if count:
            count_dict['signed_curve25519'] = count

        self.device.update_one_time_key_counts(count_dict)

        if should_upload:
            if count:
                req_otk = json.loads(responses.calls[0].request.body)['one_time_keys']
                assert len(responses.calls) == 1
            else:
                # With an empty count dict the keys travel in the second request.
                req_otk = json.loads(responses.calls[1].request.body)['one_time_keys']
                assert len(responses.calls) == 2
            assert len(req_otk) == 50 - count
        else:
            assert not len(responses.calls)

    @pytest.mark.parametrize('threshold', [-1, 2])
    def test_invalid_keys_threshold(self, threshold):
        """An out-of-range ``keys_threshold`` is rejected with ValueError."""
        with pytest.raises(ValueError):
            OlmDevice(self.cli.api,
                      self.user_id,
                      self.device_id,
                      keys_threshold=threshold)