pyzmq/tests/conftest.py

216 lines
5.4 KiB
Python

"""pytest configuration and fixtures"""
import asyncio
import inspect
import os
import signal
import time
from contextlib import contextmanager
from functools import partial
from threading import Thread
try:
import tornado
from tornado import version_info
except ImportError:
tornado = None
else:
if version_info < (5,):
tornado = None
from tornado.ioloop import IOLoop
import pytest
import zmq
import zmq.asyncio
# Optional per-test timeout taken from the environment; None when unset.
# NOTE: environment values are strings, so consumers must convert before use.
test_timeout_seconds = os.getenv("ZMQ_TEST_TIMEOUT")
# Seconds to wait for each tracked context to terminate during teardown.
teardown_timeout = 10
def pytest_collection_modifyitems(items):
    """Mark coroutine tests for pytest-asyncio and reject async generators.

    pytest runs this hook automatically with every collected test item.
    Coroutine test functions get the ``asyncio`` marker. Async generator
    test functions make no sense as tests, so they raise AssertionError.
    Items without an ``obj`` attribute (not backed by a Python callable)
    are ignored instead of crashing with AttributeError.

    The marker is no longer required with pytest-asyncio >= 0.17.
    """
    for item in items:
        test_func = getattr(item, "obj", None)
        if test_func is None:
            continue
        if inspect.iscoroutinefunction(test_func):
            item.add_marker('asyncio')
        assert not inspect.isasyncgenfunction(test_func), (
            f"async generator test functions are not supported: {item}"
        )
@pytest.fixture
async def io_loop(event_loop, request):
    """Create a tornado IOLoop on the currently running asyncio event loop.

    Skips the requesting test when tornado is unavailable or older than 5.
    The IOLoop is closed, along with all file descriptors registered on
    it (``all_fds=True``), by a finalizer at fixture teardown.
    """
    if tornado is None:
        pytest.skip("tornado >= 5 is required")
    loop = IOLoop.current()
    # Inside a coroutine, get_running_loop() is the documented way to find
    # the loop; get_event_loop() has policy-dependent behavior.
    assert asyncio.get_running_loop() is event_loop
    assert loop.asyncio_loop is event_loop
    request.addfinalizer(partial(loop.close, all_fds=True))
    return loop
def term_context(ctx, timeout):
    """Terminate *ctx*, giving up after *timeout* seconds.

    ``ctx.term()`` runs in a daemon thread so that a context that cannot
    terminate does not block the caller forever. If the thread is still
    alive once the timeout expires, the cached ``Context.instance`` is reset
    and RuntimeError is raised.
    """
    terminator = Thread(target=ctx.term, daemon=True)
    terminator.start()
    terminator.join(timeout=timeout)
    if not terminator.is_alive():
        return
    # reset Context.instance, so the failure to term doesn't corrupt subsequent tests
    zmq.sugar.context.Context._instance = None
    raise RuntimeError(
        f"context {ctx} could not terminate, open sockets likely remain in test"
    )
@pytest.fixture
def event_loop():
    """Provide a brand-new asyncio event loop for one test.

    The loop is closed at teardown. The fixture then checks that no zmq
    asyncio selectors remain registered.
    """
    fresh_loop = asyncio.new_event_loop()
    yield fresh_loop
    fresh_loop.close()
    # make sure selectors are cleared
    leftover_selectors = dict(zmq.asyncio._selectors)
    assert leftover_selectors == {}
@pytest.fixture
def sigalrm_timeout():
    """Set timeout using SIGALRM.

    Avoids an infinite hang in context.term for an unclean context by
    raising TimeoutError instead. The alarm is active only when the platform
    has SIGALRM and ``test_timeout_seconds`` (ZMQ_TEST_TIMEOUT) is set.
    At teardown the pending alarm is cancelled and the previous SIGALRM
    handler is restored, so the alarm cannot fire in a later, unrelated test.
    """
    if not hasattr(signal, "SIGALRM") or not test_timeout_seconds:
        yield
        return
    # Environment values are strings, but signal.alarm() requires an int.
    seconds = int(test_timeout_seconds)

    def _alarm_timeout(*args):
        raise TimeoutError(f"Test did not complete in {seconds} seconds")

    previous_handler = signal.signal(signal.SIGALRM, _alarm_timeout)
    signal.alarm(seconds)
    yield
    # Cancel any pending alarm and put the previous handler back.
    signal.alarm(0)
    # signal.signal() returns None when the old handler was not set from Python.
    if previous_handler is not None:
        signal.signal(signal.SIGALRM, previous_handler)
@pytest.fixture
def Context():
    """Return the context class that the ``context`` fixture instantiates.

    Test modules can override this fixture to substitute a different
    implementation (e.g. zmq.green).
    """
    context_class = zmq.Context
    return context_class
@pytest.fixture
def contexts(sigalrm_timeout):
    """Fixture to track contexts used in tests, for cleanup purposes.

    Every context in the yielded set is terminated at teardown, each with
    ``teardown_timeout`` seconds to finish. A failure to terminate one
    context no longer prevents the others from being terminated. All are
    attempted, and the first error is re-raised afterwards.
    """
    contexts = set()
    yield contexts
    first_error = None
    for ctx in contexts:
        try:
            term_context(ctx, teardown_timeout)
        except Exception as e:
            # reset Context.instance, so the failure to term doesn't corrupt subsequent tests
            zmq.sugar.context.Context._instance = None
            if first_error is None:
                first_error = e
    if first_error is not None:
        raise first_error
@pytest.fixture
def context(Context, contexts):
    """Provide a shared context built from the ``Context`` fixture's class.

    The context is registered in ``contexts`` so it is terminated at teardown.
    """
    shared_ctx = Context()
    contexts.add(shared_ctx)
    return shared_ctx
@pytest.fixture
def sockets(contexts):
    """Yield a list for tracking sockets, and close them at teardown.

    Before closing, each tracked socket's context is registered in
    ``contexts`` so that the context also gets cleaned up.
    """
    tracked = []
    yield tracked
    # ensure any tracked sockets get their contexts cleaned up
    contexts.update(sock.context for sock in tracked)
    # then close every socket without lingering
    for sock in tracked:
        sock.close(linger=0)
@pytest.fixture
def socket(context, sockets):
    """Return a factory that creates sockets on ``context``.

    Every socket it creates is appended to ``sockets``, so it is closed at
    teardown.
    """

    def _make_tracked_socket(*args, **kwargs):
        sock = context.socket(*args, **kwargs)
        sockets.append(sock)
        return sock

    return _make_tracked_socket
@contextmanager
def assert_raises_errno(errno):
    """Context manager asserting that its body raises ZMQError with *errno*.

    The test fails if the body raises nothing, or raises a ZMQError with a
    different errno. Any other exception propagates unchanged. Without the
    ``@contextmanager`` decorator this was a bare generator and could not be
    used in a ``with`` statement.

    Usage::

        with assert_raises_errno(zmq.EAGAIN):
            sock.recv(zmq.DONTWAIT)
    """
    try:
        yield
    except zmq.ZMQError as e:
        assert (
            e.errno == errno
        ), f"wrong error raised, expected {zmq.ZMQError(errno)} got {zmq.ZMQError(e.errno)}"
    else:
        pytest.fail(f"Expected {zmq.ZMQError(errno)}, no error raised")
def recv(socket, *, timeout=5, flags=0, multipart=False, **kwargs):
    """Call recv[_multipart] in a way that raises if there is nothing to receive.

    :param socket: the socket to read from.
    :param timeout: how long ``zmq.select`` waits for the socket to become readable.
    :param flags: extra recv flags; ``zmq.DONTWAIT`` is always added, so the
        final receive never blocks.
    :param multipart: call ``recv_multipart`` instead of ``recv``.
    :param kwargs: forwarded unchanged to the receive call.
    :raises AssertionError: if the socket is not readable within *timeout*.
    """
    if zmq.zmq_version_info() >= (3, 1, 0):
        # zmq 3.1 has a bug, where poll can return false positives,
        # so we wait a little bit just in case
        # See LIBZMQ-280 on JIRA
        time.sleep(0.1)
    readable, _, _ = zmq.select([socket], [], [], timeout=timeout)
    assert readable, "Should have received a message"
    # ``flags`` is keyword-only, so it never appears in **kwargs. Merge
    # DONTWAIT into it directly; the old code put it in kwargs as well and
    # then passed ``flags`` twice, which always raised TypeError.
    flags |= zmq.DONTWAIT
    receive = socket.recv_multipart if multipart else socket.recv
    return receive(flags=flags, **kwargs)
# Variant of recv() that receives a whole multipart message.
# As a functools.partial, callers can still override ``multipart`` explicitly.
recv_multipart = partial(recv, multipart=True)
@pytest.fixture
def create_bound_pair(socket):
    """Return a factory that builds connected socket pairs.

    Both sockets come from the tracking ``socket`` fixture, so they are
    closed at teardown.
    """

    def create_bound_pair(type1=zmq.PAIR, type2=zmq.PAIR, interface='tcp://127.0.0.1'):
        """Create a bound socket pair using a random port."""

        def _zero_linger(sock_type):
            # linger=0 so closing never waits on undelivered messages
            sock = socket(sock_type)
            sock.linger = 0
            return sock

        bound = _zero_linger(type1)
        port = bound.bind_to_random_port(interface)
        connected = _zero_linger(type2)
        connected.connect(f'{interface}:{port}')
        return bound, connected

    return create_bound_pair
@pytest.fixture
def bound_pair(create_bound_pair):
    """Provide a connected socket pair built with the factory's defaults."""
    pair = create_bound_pair()
    return pair
@pytest.fixture
def push_pull(create_bound_pair):
    """Provide a bound PUSH socket and a PULL socket connected to it."""
    push, pull = create_bound_pair(type1=zmq.PUSH, type2=zmq.PULL)
    return push, pull
@pytest.fixture
def dealer_router(create_bound_pair):
    """Provide a bound DEALER socket and a ROUTER socket connected to it."""
    dealer, router = create_bound_pair(type1=zmq.DEALER, type2=zmq.ROUTER)
    return dealer, router