mirror of https://github.com/zeromq/pyzmq.git
417 lines
14 KiB
Python
417 lines
14 KiB
Python
# Copyright (C) PyZMQ Developers
|
|
# Distributed under the terms of the Modified BSD License.
|
|
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import shutil
|
|
import sys
|
|
import warnings
|
|
from contextlib import contextmanager
|
|
|
|
import pytest
|
|
|
|
import zmq
|
|
import zmq.asyncio
|
|
import zmq.auth
|
|
from zmq.tests import SkipTest, skip_pypy
|
|
|
|
try:
|
|
import tornado
|
|
except ImportError:
|
|
tornado = None
|
|
|
|
|
|
@pytest.fixture
def Context(event_loop):
    """Provide the asyncio-flavoured zmq Context class (not an instance).

    ``event_loop`` is requested only so that fixture is set up before this
    one; its value is not used here.
    """
    context_class = zmq.asyncio.Context
    return context_class
|
|
|
|
|
|
@pytest.fixture
def create_certs(tmpdir):
    """Generate server and client CURVE certificates for one test.

    All key files are first written to a staging directory under ``tmpdir``
    and then sorted into separate public and private directories.

    Returns:
        A ``(public_keys_dir, secret_keys_dir)`` tuple of directory paths.
    """
    root = str(tmpdir.join("certs").mkdir())
    staging_dir, public_dir, secret_dir = (
        os.path.join(root, name)
        for name in ("certificates", "public_keys", "private_keys")
    )
    for directory in (staging_dir, public_dir, secret_dir):
        os.mkdir(directory)

    # one key pair per role; only the files written to disk are used below
    for owner in ("server", "client"):
        zmq.auth.create_certificates(staging_dir, owner)

    # public keys end in ".key", secret keys in ".key_secret"; the suffixes
    # do not overlap under endswith(), so each file is moved exactly once
    routing = ((".key", public_dir), (".key_secret", secret_dir))
    for suffix, target_dir in routing:
        for filename in os.listdir(staging_dir):
            if filename.endswith(suffix):
                shutil.move(
                    os.path.join(staging_dir, filename),
                    os.path.join(target_dir, '.'),
                )

    return public_dir, secret_dir
|
|
|
|
|
|
def load_certs(secret_keys_dir):
    """Load the server and client key pairs stored in *secret_keys_dir*.

    Reads ``server.key_secret`` and then ``client.key_secret`` from the
    directory.

    Returns:
        A 4-tuple ``(server_public, server_secret, client_public,
        client_secret)``.
    """
    loaded = []
    for filename in ("server.key_secret", "client.key_secret"):
        secret_file = os.path.join(secret_keys_dir, filename)
        public_key, secret_key = zmq.auth.load_certificate(secret_file)
        loaded += [public_key, secret_key]
    return tuple(loaded)
|
|
|
|
|
|
@pytest.fixture
def public_keys_dir(create_certs):
    """Return the public-key directory from the ``create_certs`` pair."""
    public_dir, _secret_dir = create_certs
    return public_dir
|
|
|
|
|
|
@pytest.fixture
def secret_keys_dir(create_certs):
    """Return the secret-key directory from the ``create_certs`` pair."""
    _public_dir, secret_dir = create_certs
    return secret_dir
|
|
|
|
|
|
@pytest.fixture
def certs(secret_keys_dir):
    """Return the key tuple produced by ``load_certs`` for this test's keys."""
    key_tuple = load_certs(secret_keys_dir)
    return key_tuple
|
|
|
|
|
|
@pytest.fixture
async def _async_setup(request, event_loop):
    """Wrap a test in the instance's async_setup / async_teardown hooks.

    Exists because pytest doesn't support async setup/teardown.
    """
    test_case = request.instance
    await test_case.async_setup()
    yield
    # make sure our teardown runs before the loop closes
    test_case.async_teardown()
|
|
|
|
|
|
@pytest.mark.usefixtures("_async_setup")
class AuthTest:
    """Authentication test cases shared by every authenticator backend.

    Subclasses provide the authenticator under test by overriding
    ``make_auth``.  The ``_async_setup`` fixture named in the decorator runs
    ``async_setup`` before each test and ``async_teardown`` after it.
    """

    # authenticator under test; reset to None whenever it has been stopped
    auth = None

    async def async_setup(self):
        """Create the context and start the authenticator for one test.

        Raises:
            SkipTest: if libzmq is older than 4.0 or has no CURVE support.
        """
        # Run the skip checks before allocating anything: when setup raises,
        # async_teardown is never reached, so a Context created earlier
        # would not be terminated.
        if zmq.zmq_version_info() < (4, 0):
            raise SkipTest("security is new in libzmq 4.0")
        try:
            zmq.curve_keypair()
        except zmq.ZMQError:
            raise SkipTest("security requires libzmq to have curve support")

        self.context = zmq.asyncio.Context()
        # enable debug logging while we run tests
        logging.getLogger('zmq.auth').setLevel(logging.DEBUG)
        self.auth = self.make_auth()
        await self.start_auth()

    def async_teardown(self):
        """Stop the authenticator if one is still set, then terminate the context."""
        if self.auth:
            self.auth.stop()
            self.auth = None
        self.context.term()

    def make_auth(self):
        """Return the authenticator to test; must be overridden by subclasses."""
        raise NotImplementedError()

    async def start_auth(self):
        """Start the authenticator created by ``make_auth``."""
        self.auth.start()

    async def can_connect(self, server, client, timeout=1000):
        """Check if client can connect to server using tcp transport

        ``server`` is bound to a random port on 127.0.0.1, ``client`` connects
        to it, and one message is sent from ``server`` to ``client``.

        Args:
            server: socket that binds and sends.
            client: socket that connects and receives.
            timeout: value handed to each ``poll`` call (presumably
                milliseconds, per zmq's poll convention).

        Returns:
            True if the message arrived, False if either side timed out or
            the non-blocking send/recv raised ``zmq.Again``.
        """
        iface = 'tcp://127.0.0.1'
        port = server.bind_to_random_port(iface)
        client.connect("%s:%i" % (iface, port))
        msg = [b"Hello World"]
        # run poll on server twice
        # to flush spurious events
        await server.poll(100, zmq.POLLOUT)

        if not await server.poll(timeout, zmq.POLLOUT):
            return False
        try:
            await server.send_multipart(msg, zmq.NOBLOCK)
        except zmq.Again:
            warnings.warn("server set POLLOUT, but cannot send", RuntimeWarning)
            return False

        if not await client.poll(timeout):
            return False
        try:
            rcvd_msg = await client.recv_multipart(zmq.NOBLOCK)
        except zmq.Again:
            warnings.warn("client set POLLIN, but cannot recv", RuntimeWarning)
            return False

        assert rcvd_msg == msg
        return True

    @contextmanager
    def push_pull(self):
        """Yield a ``(server, client)`` pair: a PUSH socket and a PULL socket.

        Both sockets get zero linger and 2000 send/receive timeouts, and are
        closed when the context manager exits.
        """
        with self.context.socket(zmq.PUSH) as server, self.context.socket(
            zmq.PULL
        ) as client:
            server.linger = 0
            server.sndtimeo = 2000
            client.linger = 0
            client.rcvtimeo = 2000
            yield server, client

    @contextmanager
    def curve_push_pull(self, certs, client_key="ok"):
        """Yield a ``push_pull`` pair with CURVE configured on both sockets.

        Args:
            certs: ``(server_public, server_secret, client_public,
                client_secret)`` keys.
            client_key: ``"ok"`` gives the client the real server key, any
                other non-None value gives it an unrelated server key, and
                ``None`` leaves the client without any CURVE settings.
        """
        server_public, server_secret, client_public, client_secret = certs
        with self.push_pull() as (server, client):
            server.curve_publickey = server_public
            server.curve_secretkey = server_secret
            server.curve_server = True
            if client_key is not None:
                client.curve_publickey = client_public
                client.curve_secretkey = client_secret
                if client_key == "ok":
                    client.curve_serverkey = server_public
                else:
                    # curve_keypair() returns (public, secret); use the public
                    # half of a throwaway pair as a well-formed but wrong key
                    wrong_public, _wrong_secret = zmq.curve_keypair()
                    client.curve_serverkey = wrong_public
            yield (server, client)

    async def test_null(self):
        """auth - NULL"""
        # A default NULL connection should always succeed, and not
        # go through our authentication infrastructure at all.
        self.auth.stop()
        self.auth = None

        # use a new context, so ZAP isn't inherited
        self.context.term()
        self.context = zmq.asyncio.Context()

        with self.push_pull() as (server, client):
            assert await self.can_connect(server, client)

        # By setting a domain we switch on authentication for NULL sockets,
        # though no policies are configured yet. The client connection
        # should still be allowed.
        with self.push_pull() as (server, client):
            server.zap_domain = b'global'
            assert await self.can_connect(server, client)

    async def test_deny(self):
        """auth - NULL with a denied address"""
        # deny 127.0.0.1, connection should fail
        self.auth.deny('127.0.0.1')
        # once deny() has been used, allow() is expected to be rejected
        with pytest.raises(ValueError):
            self.auth.allow("127.0.0.2")
        with self.push_pull() as (server, client):
            # By setting a domain we switch on authentication for NULL
            # sockets, so the deny policy above is applied.
            server.zap_domain = b'global'
            assert not await self.can_connect(server, client, timeout=100)

    async def test_allow(self):
        """auth - NULL with an allowed address"""
        # allow 127.0.0.1, connection should pass
        self.auth.allow('127.0.0.1')
        # once allow() has been used, deny() is expected to be rejected
        with pytest.raises(ValueError):
            self.auth.deny("127.0.0.2")
        with self.push_pull() as (server, client):
            # By setting a domain we switch on authentication for NULL
            # sockets, so the allow policy above is applied.
            server.zap_domain = b'global'
            assert await self.can_connect(server, client)

    @pytest.mark.parametrize(
        "enabled, password, success",
        [
            (True, "correct", True),
            (False, "correct", False),
            (True, "incorrect", False),
        ],
    )
    async def test_plain(self, enabled, password, success):
        """auth - PLAIN"""

        # PLAIN should succeed only when the authenticator has been
        # configured (enabled) and the client presents the matching password.
        with self.push_pull() as (server, client):
            server.plain_server = True
            if password:
                client.plain_username = b'admin'
                client.plain_password = password.encode("ascii")
            if enabled:
                self.auth.configure_plain(domain='*', passwords={'admin': 'correct'})
            if success:
                assert await self.can_connect(server, client)
            else:
                assert not await self.can_connect(server, client, timeout=100)

        # Remove authenticator and check that a normal connection works
        self.auth.stop()
        self.auth = None
        with self.push_pull() as (server, client):
            assert await self.can_connect(server, client)

    @pytest.mark.parametrize(
        "client_key, location, success",
        [
            ('ok', zmq.auth.CURVE_ALLOW_ANY, True),
            ('ok', "public_keys", True),
            ('bad', "public_keys", False),
            (None, "public_keys", False),
        ],
    )
    async def test_curve(self, certs, public_keys_dir, client_key, location, success):
        """auth - CURVE"""
        self.auth.allow('127.0.0.1')

        # CURVE should succeed only when the client holds the real server key
        # and the authenticator accepts the client's public key.
        with self.curve_push_pull(certs, client_key=client_key) as (server, client):
            if location:
                # "public_keys" is a placeholder for the per-test directory
                if location == 'public_keys':
                    location = public_keys_dir
                self.auth.configure_curve(domain='*', location=location)
            if success:
                assert await self.can_connect(server, client, timeout=100)
            else:
                assert not await self.can_connect(server, client, timeout=100)

        # Remove authenticator and check that a normal connection works
        self.auth.stop()
        self.auth = None

        # Try connecting using NULL and no authentication enabled, connection should pass
        with self.push_pull() as (server, client):
            assert await self.can_connect(server, client)

    @pytest.mark.parametrize("key", ["ok", "wrong"])
    @pytest.mark.parametrize("async_", [True, False])
    async def test_curve_callback(self, certs, key, async_):
        """auth - CURVE with callback authentication"""
        self.auth.allow('127.0.0.1')
        server_public, server_secret, client_public, client_secret = certs

        class CredentialsProvider:
            """Accepts exactly one client key, chosen by the ``key`` parameter."""

            def __init__(self):
                # "wrong" makes the provider expect the server's public key,
                # which the connecting client never presents
                if key == "ok":
                    self.client = client_public
                else:
                    self.client = server_public

            def callback(self, domain, key):
                return key == self.client

            async def async_callback(self, domain, key):
                await asyncio.sleep(0.1)
                return key == self.client

        if async_:
            # exercise the coroutine form of the same check
            CredentialsProvider.callback = CredentialsProvider.async_callback
        provider = CredentialsProvider()
        self.auth.configure_curve_callback(credentials_provider=provider)
        with self.curve_push_pull(certs) as (server, client):
            if key == "ok":
                assert await self.can_connect(server, client)
            else:
                assert not await self.can_connect(server, client, timeout=200)

    @skip_pypy
    async def test_curve_user_id(self, certs, public_keys_dir):
        """auth - CURVE User-Id metadata"""
        self.auth.allow('127.0.0.1')
        server_public, server_secret, client_public, client_secret = certs
        self.auth.configure_curve(domain='*', location=public_keys_dir)
        # reverse server-client relationship, so server is PULL
        with self.push_pull() as (client, server):
            server.curve_publickey = server_public
            server.curve_secretkey = server_secret
            server.curve_server = True

            client.curve_publickey = client_public
            client.curve_secretkey = client_secret
            client.curve_serverkey = server_public

            assert await self.can_connect(client, server)

            # test default user-id map
            await client.send(b'test')
            msg = await server.recv(copy=False)
            assert msg.bytes == b'test'
            try:
                user_id = msg.get('User-Id')
            except zmq.ZMQVersionError:
                # frame metadata is not available with this libzmq
                pass
            else:
                assert user_id == client_public.decode("utf8")

            # test custom user-id map
            self.auth.curve_user_id = lambda client_key: 'custom'

            with self.context.socket(zmq.PUSH) as client2:
                client2.curve_publickey = client_public
                client2.curve_secretkey = client_secret
                client2.curve_serverkey = server_public
                assert await self.can_connect(client2, server)

                await client2.send(b'test2')
                msg = await server.recv(copy=False)
                assert msg.bytes == b'test2'
                try:
                    user_id = msg.get('User-Id')
                except zmq.ZMQVersionError:
                    # frame metadata is not available with this libzmq
                    pass
                else:
                    assert user_id == 'custom'
|
|
|
|
|
|
class TestThreadAuthentication(AuthTest):
    """Run the shared authentication tests against ThreadAuthenticator."""

    def make_auth(self):
        """Build a ThreadAuthenticator on this test's context."""
        from zmq.auth.thread import ThreadAuthenticator

        authenticator = ThreadAuthenticator(self.context)
        return authenticator
|
|
|
|
|
|
@pytest.mark.skipif(
    sys.platform == 'win32' and sys.version_info < (3, 7),
    reason="flaky event loop cleanup on windows+py36",
)
class TestAsyncioAuthentication(AuthTest):
    """Run the shared authentication tests against AsyncioAuthenticator."""

    def make_auth(self):
        """Build an AsyncioAuthenticator on this test's context."""
        from zmq.auth.asyncio import AsyncioAuthenticator

        authenticator = AsyncioAuthenticator(self.context)
        return authenticator
|
|
|
|
|
|
async def test_ioloop_authenticator(context, event_loop, io_loop):
    """IOLoopAuthenticator keeps its context and warns when given ``io_loop``.

    ``event_loop`` and ``io_loop`` are requested as fixtures only; their
    values are not used in the body.
    """
    from tornado.ioloop import IOLoop

    with warnings.catch_warnings():
        # catch_warnings() on its own only saves and restores the filter
        # state; an explicit filter is required to actually silence a
        # DeprecationWarning raised while importing zmq.auth.ioloop.
        warnings.simplefilter("ignore", DeprecationWarning)
        from zmq.auth.ioloop import IOLoopAuthenticator

    auth = IOLoopAuthenticator(context)
    assert auth.context is context

    loop = IOLoop(make_current=False)
    try:
        with pytest.warns(DeprecationWarning):
            auth = IOLoopAuthenticator(io_loop=loop)
    finally:
        # this loop exists only for the constructor call above; close it so
        # it is not left open once the test ends
        loop.close()
|