mirror of https://github.com/zeromq/pyzmq.git
216 lines
5.4 KiB
Python
216 lines
5.4 KiB
Python
"""pytest configuration and fixtures"""
|
|
|
|
import asyncio
|
|
import inspect
|
|
import os
|
|
import signal
|
|
import time
|
|
from functools import partial
|
|
from threading import Thread
|
|
|
|
try:
|
|
import tornado
|
|
from tornado import version_info
|
|
except ImportError:
|
|
tornado = None
|
|
else:
|
|
if version_info < (5,):
|
|
tornado = None
|
|
from tornado.ioloop import IOLoop
|
|
|
|
import pytest
|
|
|
|
import zmq
|
|
import zmq.asyncio
|
|
|
|
# Optional per-test timeout, taken verbatim from the ZMQ_TEST_TIMEOUT
# environment variable (a string when set, None when unset).
test_timeout_seconds = os.getenv("ZMQ_TEST_TIMEOUT")
# Seconds to wait for each context to terminate during fixture teardown.
teardown_timeout = 10
|
|
|
|
|
|
def pytest_collection_modifyitems(items):
    """Pytest hook: post-process every collected test item.

    Each test whose function is a coroutine function gets the ``asyncio``
    marker, and async-generator test functions are rejected with an
    assertion because they make no sense as tests.

    It is no longer required with pytest-asyncio >= 0.17
    """
    for collected in items:
        test_func = collected.obj
        if inspect.iscoroutinefunction(test_func):
            collected.add_marker('asyncio')
        # checked for every item, not only the coroutine ones
        assert not inspect.isasyncgenfunction(test_func)
|
|
|
|
|
|
@pytest.fixture
async def io_loop(event_loop, request):
    """Provide the tornado IOLoop running on the current asyncio event loop.

    The test is skipped when the module-level ``tornado`` is None.
    A finalizer closes the IOLoop together with all of its file descriptors.
    """
    if tornado is None:
        pytest.skip()

    tornado_loop = IOLoop.current()
    # both asyncio and tornado must agree on the loop from the event_loop fixture
    assert asyncio.get_event_loop() is event_loop
    assert tornado_loop.asyncio_loop is event_loop

    request.addfinalizer(lambda: tornado_loop.close(all_fds=True))
    return tornado_loop
|
|
|
|
|
|
def term_context(ctx, timeout):
    """Terminate ``ctx`` in a background thread, waiting at most ``timeout`` seconds.

    ``ctx.term`` runs in a daemon thread so that a hanging termination
    cannot block the caller forever.

    Raises:
        RuntimeError: if the thread is still running once the timeout elapses.
    """
    terminator = Thread(target=ctx.term, daemon=True)
    terminator.start()
    terminator.join(timeout=timeout)
    if not terminator.is_alive():
        return

    # reset Context.instance, so the failure to term doesn't corrupt subsequent tests
    zmq.sugar.context.Context._instance = None
    raise RuntimeError(
        f"context {ctx} could not terminate, open sockets likely remain in test"
    )
|
|
|
|
|
|
@pytest.fixture
def event_loop():
    """Provide a brand-new asyncio event loop and close it after the test."""
    fresh_loop = asyncio.new_event_loop()
    yield fresh_loop
    fresh_loop.close()
    # verify that no selectors remain registered in zmq.asyncio
    remaining_selectors = dict(zmq.asyncio._selectors)
    assert remaining_selectors == {}
|
|
|
|
|
|
@pytest.fixture
def sigalrm_timeout():
    """Set timeout using SIGALRM

    Avoids infinite hang in context.term for an unclean context,
    raising an error instead.

    The timeout comes from the module-level ``test_timeout_seconds``.
    The fixture does nothing when that value is unset or not positive,
    or when the platform has no ``SIGALRM``. On teardown the pending
    alarm is cancelled and the previous handler is restored, so the alarm
    cannot fire during a later, unrelated test.

    Raises:
        TimeoutError: from the signal handler, when the alarm goes off.
    """
    # The value is read from the environment as a string; signal.alarm needs an int.
    seconds = int(test_timeout_seconds) if test_timeout_seconds else 0
    if not hasattr(signal, "SIGALRM") or seconds <= 0:
        # a generator fixture must still yield exactly once
        yield
        return

    def _alarm_timeout(*args):
        raise TimeoutError(f"Test did not complete in {seconds} seconds")

    previous_handler = signal.signal(signal.SIGALRM, _alarm_timeout)
    signal.alarm(seconds)

    yield

    # disarm the alarm before handing control back to pytest
    signal.alarm(0)
    # signal.signal returns None when the old handler was not installed from Python;
    # such a handler cannot be re-installed, so leave ours in place (the alarm is off)
    if previous_handler is not None:
        signal.signal(signal.SIGALRM, previous_handler)
|
|
|
|
|
|
@pytest.fixture
def Context():
    """Fixture returning the Context class that tests should instantiate.

    Override in modules to specify a different class (e.g. zmq.green)
    """
    context_class = zmq.Context
    return context_class
|
|
|
|
|
|
@pytest.fixture
def contexts(sigalrm_timeout):
    """Fixture yielding a set in which tests register the contexts they use.

    After the test, every registered context is terminated via
    ``term_context`` with ``teardown_timeout``; the first failure is re-raised.
    """
    tracked = set()
    yield tracked
    for tracked_ctx in tracked:
        try:
            term_context(tracked_ctx, teardown_timeout)
        except Exception:
            # reset Context.instance, so the failure to term doesn't corrupt subsequent tests
            zmq.sugar.context.Context._instance = None
            raise
|
|
|
|
|
|
@pytest.fixture
def context(Context, contexts):
    """Fixture for a shared context, registered in ``contexts`` for cleanup."""
    shared_ctx = Context()
    contexts.add(shared_ctx)
    return shared_ctx
|
|
|
|
|
|
@pytest.fixture
def sockets(contexts):
    """Fixture yielding a list in which tests register sockets for cleanup.

    After the test, the context of every registered socket is added to
    ``contexts`` and then each socket is closed with ``linger=0``.
    """
    tracked = []
    yield tracked
    # ensure any tracked sockets get their contexts cleaned up
    contexts.update(sock.context for sock in tracked)

    # close sockets
    for sock in tracked:
        sock.close(linger=0)
|
|
|
|
|
|
@pytest.fixture
def socket(context, sockets):
    """Fixture returning a socket factory that records each socket for cleanup.

    The factory forwards its arguments to ``context.socket`` and appends
    the new socket to ``sockets``.
    """

    def new_socket(*args, **kwargs):
        created = context.socket(*args, **kwargs)
        sockets.append(created)
        return created

    return new_socket
|
|
|
|
|
|
def assert_raises_errno(errno):
    """Return a context manager asserting that its body raises ZMQError ``errno``.

    Usage::

        with assert_raises_errno(zmq.EAGAIN):
            ...

    The assertion fails when a ``zmq.ZMQError`` with a different errno is
    raised, and ``pytest.fail`` is called when the body raises nothing.
    Exceptions of other types propagate unchanged.
    """
    # Imported locally so this helper does not depend on the file's import block.
    from contextlib import contextmanager

    # A bare generator has no __enter__/__exit__, so it cannot be used in a
    # ``with`` statement; wrap it with contextmanager.
    @contextmanager
    def _expect_errno():
        try:
            yield
        except zmq.ZMQError as e:
            assert (
                e.errno == errno
            ), f"wrong error raised, expected {zmq.ZMQError(errno)} got {zmq.ZMQError(e.errno)}"
        else:
            pytest.fail(f"Expected {zmq.ZMQError(errno)}, no error raised")

    return _expect_errno()
|
|
|
|
|
|
def recv(socket, *, timeout=5, flags=0, multipart=False, **kwargs):
    """Call recv[_multipart] in a way that raises if there is nothing to receive.

    Args:
        socket: the socket to receive from.
        timeout: passed to ``zmq.select`` as the time to wait for readability.
        flags: receive flags; ``zmq.DONTWAIT`` is always OR-ed in.
        multipart: use ``socket.recv_multipart`` instead of ``socket.recv``.
        **kwargs: forwarded to the chosen receive method.

    Returns:
        Whatever the chosen receive method returns.

    Raises:
        AssertionError: if the socket does not become readable within ``timeout``.
    """
    if zmq.zmq_version_info() >= (3, 1, 0):
        # zmq 3.1 has a bug, where poll can return false positives,
        # so we wait a little bit just in case
        # See LIBZMQ-280 on JIRA
        time.sleep(0.1)

    readable, _, _ = zmq.select([socket], [], [], timeout=timeout)
    assert readable, "Should have received a message"

    # Merge DONTWAIT into the caller's flags and pass the result exactly once.
    # ``flags`` is a named parameter, so it can never also be inside ``kwargs``;
    # storing it there and passing ``flags=`` as well raised TypeError.
    flags = zmq.DONTWAIT | flags

    # local name deliberately differs from the enclosing function's name
    receive = socket.recv_multipart if multipart else socket.recv
    return receive(flags=flags, **kwargs)
|
|
|
|
|
|
# Convenience wrapper: ``recv`` with ``multipart=True`` pre-bound.
recv_multipart = partial(recv, multipart=True)
|
|
|
|
|
|
@pytest.fixture
def create_bound_pair(socket):
    """Fixture returning a factory that builds connected socket pairs.

    Both sockets come from the ``socket`` fixture's factory.
    """

    def create_bound_pair(type1=zmq.PAIR, type2=zmq.PAIR, interface='tcp://127.0.0.1'):
        """Create a bound socket pair using a random port."""
        bound = socket(type1)
        bound.linger = 0
        port = bound.bind_to_random_port(interface)

        connected = socket(type2)
        connected.linger = 0
        connected.connect(f'{interface}:{port}')
        return bound, connected

    return create_bound_pair
|
|
|
|
|
|
@pytest.fixture
def bound_pair(create_bound_pair):
    """Fixture providing a socket pair built with create_bound_pair's defaults."""
    pair = create_bound_pair()
    return pair
|
|
|
|
|
|
@pytest.fixture
def push_pull(create_bound_pair):
    """Fixture providing a connected (PUSH, PULL) socket pair; PUSH is the bound end."""
    pair = create_bound_pair(zmq.PUSH, zmq.PULL)
    return pair
|
|
|
|
|
|
@pytest.fixture
def dealer_router(create_bound_pair):
    """Fixture providing a connected (DEALER, ROUTER) socket pair; DEALER is the bound end."""
    pair = create_bound_pair(zmq.DEALER, zmq.ROUTER)
    return pair
|