pyzmq/zmq/tests/test_zmqstream.py

160 lines
4.1 KiB
Python

# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import asyncio
import logging
import warnings
import pytest
import zmq
import zmq.asyncio
try:
import tornado
from zmq.eventloop import zmqstream
except ImportError:
tornado = None # type: ignore
pytestmark = pytest.mark.usefixtures("io_loop")
@pytest.fixture
async def push_pull(socket):
push = zmqstream.ZMQStream(socket(zmq.PUSH))
pull = zmqstream.ZMQStream(socket(zmq.PULL))
port = push.bind_to_random_port('tcp://127.0.0.1')
pull.connect('tcp://127.0.0.1:%i' % port)
return (push, pull)
@pytest.fixture
def push(push_pull):
push, pull = push_pull
return push
@pytest.fixture
def pull(push_pull):
push, pull = push_pull
return pull
async def test_callable_check(pull):
"""Ensure callable check works."""
pull.on_send(lambda *args: None)
pull.on_recv(lambda *args: None)
with pytest.raises(AssertionError):
pull.on_recv(1)
with pytest.raises(AssertionError):
pull.on_send(1)
with pytest.raises(AssertionError):
pull.on_recv(zmq)
async def test_on_recv_basic(push, pull):
sent = [b'basic']
push.send_multipart(sent)
f = asyncio.Future()
def callback(msg):
f.set_result(msg)
pull.on_recv(callback)
recvd = await asyncio.wait_for(f, timeout=5)
assert recvd == sent
async def test_on_recv_wake(push, pull):
sent = [b'wake']
f = asyncio.Future()
pull.on_recv(f.set_result)
await asyncio.sleep(0.5)
push.send_multipart(sent)
recvd = await asyncio.wait_for(f, timeout=5)
assert recvd == sent
async def test_on_recv_async(push, pull):
if tornado.version_info < (5,):
pytest.skip()
sent = [b'wake']
f = asyncio.Future()
async def callback(msg):
await asyncio.sleep(0.1)
f.set_result(msg)
pull.on_recv(callback)
await asyncio.sleep(0.5)
push.send_multipart(sent)
recvd = await asyncio.wait_for(f, timeout=5)
assert recvd == sent
async def test_on_recv_async_error(push, pull, caplog):
sent = [b'wake']
f = asyncio.Future()
async def callback(msg):
f.set_result(msg)
1 / 0
pull.on_recv(callback)
await asyncio.sleep(0.1)
with caplog.at_level(logging.ERROR, logger=zmqstream.gen_log.name):
push.send_multipart(sent)
recvd = await asyncio.wait_for(f, timeout=5)
assert recvd == sent
# logging error takes a tick later
await asyncio.sleep(0.5)
messages = [
x.message
for x in caplog.get_records("call")
if x.name == zmqstream.gen_log.name
]
assert "Uncaught exception in ZMQStream callback" in "\n".join(messages)
async def test_shadow_socket(context):
with context.socket(zmq.PUSH, socket_class=zmq.asyncio.Socket) as socket:
with pytest.warns(RuntimeWarning):
stream = zmqstream.ZMQStream(socket)
assert type(stream.socket) is zmq.Socket
assert stream.socket.underlying == socket.underlying
stream.close()
async def test_shadow_socket_close(context, caplog):
with context.socket(zmq.PUSH) as push, context.socket(zmq.PULL) as pull:
push.linger = pull.linger = 0
port = push.bind_to_random_port('tcp://127.0.0.1')
pull.connect(f'tcp://127.0.0.1:{port}')
shadow_pull = zmq.Socket.shadow(pull)
stream = zmqstream.ZMQStream(shadow_pull)
# send some messages
for i in range(10):
push.send_string(str(i))
# make sure at least one message has been delivered
pull.recv()
# register callback
# this should schedule event callback on the next tick
stream.on_recv(print)
# close the shadowed socket
pull.close()
# run the event loop, which should see some events on the shadow socket
# but the socket has been closed!
with warnings.catch_warnings(record=True) as records:
await asyncio.sleep(0.2)
warning_text = "\n".join(str(r.message) for r in records)
assert "after closing socket" in warning_text
assert "closed socket" in caplog.text