pyzmq/tests/test_monqueue.py

236 lines
8.1 KiB
Python

# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
import threading
import time
import zmq
from zmq import devices
from zmq_test_utils import PYPY, BaseZMQTestCase
# Give devices a private zmq.Context instead of the shared one wherever
# cleanup of the shared Context is unreliable: it doesn't work on PyPy, and
# libzmq 4.1+ also appears to have a cleanup bug (zeromq/libzmq#1052).
if zmq.zmq_version_info() >= (4, 1) or PYPY:
    devices.Device.context_factory = zmq.Context
class TestMonitoredQueue(BaseZMQTestCase):
    """Tests for ``devices.ThreadMonitoredQueue`` and ``devices.monitoredqueue``.

    Most tests connect two PAIR peers and a SUB socket to a device:

    - "alice" sits on the device's in side.
    - "bob" sits on the device's out side.
    - "mon" is the SUB socket on the device's monitor side.

    The tests check that messages are relayed in both directions and that the
    monitor sees each relayed message tagged with the in/out prefix.
    """

    def build_device(self, mon_sub=b"", in_prefix=b'in', out_prefix=b'out'):
        """Start a PAIR/PAIR/PUB monitored queue and return ``(alice, bob, mon)``.

        The started device is stored on ``self.device`` so that
        ``teardown_device`` can stop it.

        :param mon_sub: subscription prefix set on the monitor SUB socket.
        :param in_prefix: tag the device puts on monitored alice->bob traffic.
        :param out_prefix: tag the device puts on monitored bob->alice traffic.
        """
        self.device = devices.ThreadMonitoredQueue(
            zmq.PAIR, zmq.PAIR, zmq.PUB, in_prefix, out_prefix
        )
        alice = self.context.socket(zmq.PAIR)
        bob = self.context.socket(zmq.PAIR)
        mon = self.context.socket(zmq.SUB)

        aport = alice.bind_to_random_port('tcp://127.0.0.1')
        bport = bob.bind_to_random_port('tcp://127.0.0.1')
        mport = mon.bind_to_random_port('tcp://127.0.0.1')
        mon.setsockopt(zmq.SUBSCRIBE, mon_sub)

        self.device.connect_in(f"tcp://127.0.0.1:{aport}")
        self.device.connect_out(f"tcp://127.0.0.1:{bport}")
        self.device.connect_mon(f"tcp://127.0.0.1:{mport}")
        self.device.start()
        # Give the device thread time to start up and connect.
        time.sleep(0.2)
        try:
            # this is currently necessary to ensure no dropped monitor messages
            # see LIBZMQ-248 for more info
            mon.recv_multipart(zmq.NOBLOCK)
        except zmq.ZMQError:
            pass
        self.sockets.extend([alice, bob, mon])
        return alice, bob, mon

    def teardown_device(self):
        """Stop ``self.device`` and wait (bounded) for its thread to finish.

        The device's context is terminated from a background thread.
        ``Context.term()`` blocks until every socket on the context is closed,
        and closing the test's sockets afterwards is what lets it return.
        """
        # The device's ``_context`` may not be populated right after start(),
        # so poll for it for up to ~5 seconds.
        context = None
        for _ in range(50):
            context = getattr(self.device, "_context", None)
            if context is not None:
                break
            time.sleep(0.1)

        # Spawn term in a background thread before closing sockets.
        term_thread = None
        if context is not None:
            term_thread = threading.Thread(target=context.term, daemon=True)
            term_thread.start()
        for socket in self.sockets:
            socket.close()
        if term_thread is not None:
            term_thread.join(timeout=5)
        self.device.join(timeout=5)

    def _send_messages(self, sock):
        """Send three distinct multipart messages on *sock* and return them in order."""
        messages = [
            b"hello bob".split(),
            b"hello again".split(),
            b"hello again and again".split(),
        ]
        for message in messages:
            sock.send_multipart(message)
        return messages

    def _check_monitored_traffic(self, in_prefix=b'in', out_prefix=b'out'):
        """Relay traffic both ways and check the monitor's prefixed copies.

        This builds a device that subscribes the monitor to everything and uses
        the given prefixes. It then checks that:

        - bob receives every message alice sends;
        - alice receives bob's reply;
        - the monitor receives each relayed message tagged with *in_prefix*
          (alice->bob) or *out_prefix* (bob->alice).

        The interleaving of receives matches the original per-test code.
        """
        alice, bob, mon = self.build_device(b"", in_prefix, out_prefix)
        first, second, third = self._send_messages(alice)

        assert self.recv_multipart(bob) == first
        assert self.recv_multipart(mon) == [in_prefix] + first
        assert self.recv_multipart(bob) == second
        assert self.recv_multipart(bob) == third
        assert self.recv_multipart(mon) == [in_prefix] + second

        reply = b"hello alice".split()
        bob.send_multipart(reply)
        assert self.recv_multipart(alice) == reply
        assert self.recv_multipart(mon) == [in_prefix] + third
        assert self.recv_multipart(mon) == [out_prefix] + reply
        self.teardown_device()

    def test_reply(self):
        """A single message is relayed alice->bob and a reply bob->alice."""
        alice, bob, mon = self.build_device()
        alices = b"hello bob".split()
        alice.send_multipart(alices)
        assert self.recv_multipart(bob) == alices

        bobs = b"hello alice".split()
        bob.send_multipart(bobs)
        assert self.recv_multipart(alice) == bobs
        self.teardown_device()

    def test_queue(self):
        """Several queued messages from alice arrive at bob in order."""
        alice, bob, mon = self.build_device()
        for expected in self._send_messages(alice):
            assert self.recv_multipart(bob) == expected

        bobs = b"hello alice".split()
        bob.send_multipart(bobs)
        assert self.recv_multipart(alice) == bobs
        self.teardown_device()

    def test_monitor(self):
        """The monitor sees relayed traffic tagged with the default prefixes."""
        self._check_monitored_traffic(b'in', b'out')

    def test_prefix(self):
        """The monitor sees relayed traffic tagged with custom prefixes."""
        self._check_monitored_traffic(b'foo', b'bar')

    def test_monitor_subscribe(self):
        """A monitor subscribed to b"out" sees only bob->alice traffic."""
        alice, bob, mon = self.build_device(b"out")
        for expected in self._send_messages(alice):
            assert self.recv_multipart(bob) == expected

        bobs = b"hello alice".split()
        bob.send_multipart(bobs)
        assert self.recv_multipart(alice) == bobs
        assert self.recv_multipart(mon) == [b'out'] + bobs
        self.teardown_device()

    def test_router_router(self):
        """test router-router MQ devices"""
        dev = devices.ThreadMonitoredQueue(
            zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'in', b'out'
        )
        self.device = dev
        dev.setsockopt_in(zmq.LINGER, 0)
        dev.setsockopt_out(zmq.LINGER, 0)
        dev.setsockopt_mon(zmq.LINGER, 0)
        porta = dev.bind_in_to_random_port('tcp://127.0.0.1')
        portb = dev.bind_out_to_random_port('tcp://127.0.0.1')

        a = self.context.socket(zmq.DEALER)
        a.identity = b'a'
        b = self.context.socket(zmq.DEALER)
        b.identity = b'b'
        self.sockets.extend([a, b])

        a.connect(f'tcp://127.0.0.1:{porta}')
        b.connect(f'tcp://127.0.0.1:{portb}')
        dev.start()
        time.sleep(1)
        if zmq.zmq_version_info() >= (3, 1, 0):
            # flush erroneous poll state, due to LIBZMQ-280
            ping_msg = [b'ping', b'pong']
            for s in (a, b):
                s.send_multipart(ping_msg)
                try:
                    s.recv(zmq.NOBLOCK)
                except zmq.ZMQError:
                    pass

        # Each side addresses the other by identity. The receiver sees the
        # sender's identity as the first frame.
        msg = [b'hello', b'there']
        a.send_multipart([b'b'] + msg)
        bmsg = self.recv_multipart(b)
        assert bmsg == [b'a'] + msg
        b.send_multipart(bmsg)
        amsg = self.recv_multipart(a)
        assert amsg == [b'b'] + msg
        self.teardown_device()

    def test_default_mq_args(self):
        """A ThreadMonitoredQueue built without prefixes can be started."""
        self.device = dev = devices.ThreadMonitoredQueue(
            zmq.ROUTER, zmq.DEALER, zmq.PUB
        )
        dev.setsockopt_in(zmq.LINGER, 0)
        dev.setsockopt_out(zmq.LINGER, 0)
        dev.setsockopt_mon(zmq.LINGER, 0)
        # this will raise if default args are wrong
        dev.start()
        self.teardown_device()

    def test_mq_check_prefix(self):
        """``devices.monitoredqueue`` rejects bad argument types with TypeError."""
        ins = self.context.socket(zmq.ROUTER)
        outs = self.context.socket(zmq.DEALER)
        mons = self.context.socket(zmq.PUB)
        self.sockets.extend([ins, outs, mons])

        # NOTE(review): the str values below replace the in/out *sockets*, not
        # the prefix arguments. The TypeError may therefore come from the
        # socket-type check rather than from prefix validation. Confirm which
        # check this test is meant to cover.
        ins = 'in'
        outs = 'out'
        self.assertRaises(TypeError, devices.monitoredqueue, ins, outs, mons)