# Mirror of https://github.com/zeromq/pyzmq.git (369 lines, 11 KiB, Python)
# Copyright (C) PyZMQ Developers
|
|
# Distributed under the terms of the Modified BSD License.
|
|
|
|
|
|
import copy
|
|
import gc
|
|
import sys
|
|
|
|
# Reference-count probe: sys.getrefcount when the interpreter provides it,
# otherwise None so callers can tell that refcount checks are unavailable.
grc = None
try:
    from sys import getrefcount
except ImportError:
    pass
else:
    grc = getrefcount
|
|
|
|
import time
|
|
|
|
import zmq
|
|
from zmq_test_utils import PYPY, BaseZMQTestCase, SkipTest, skip_pypy
|
|
|
|
# Some useful constants:
x = b'x'

if grc is not None:
    # Measure how many references a memoryview adds to the object it wraps.
    # The view must stay bound (as `v`) while the second count is taken,
    # otherwise it would be released before grc(x) is called again.
    # NOTE(review): single-byte bytes objects may be immortal on newer
    # CPython, in which case view_rc could come out as 0 -- confirm before
    # relying on it.
    rc0 = grc(x)
    v = memoryview(x)
    view_rc = grc(x) - rc0
|
|
|
|
|
|
def await_gc(obj, rc, attempts=50, interval=0.05):
    """Wait for the refcount on an object to drop to an expected value.

    Necessary because of the zero-copy gc thread,
    which can take some time to receive its DECREF message.

    Parameters
    ----------
    obj : object
        The object whose reference count is being watched.
    rc : int
        The reference count to wait for, as ``grc(obj)`` would report it
        from the caller.
    attempts : int, optional
        Maximum number of polls before giving up (default 50).
    interval : float, optional
        Seconds to sleep between polls (default 0.05).

    Returns
    -------
    bool
        True if the count reached the expected value, False on timeout.
        Callers still assert on ``grc(obj)`` themselves afterwards, so a
        timeout is reported by their assertion.
    """
    # Extra references to ``obj`` attributable to this call itself while we
    # poll; the number differs before and after Python 3.11.
    my_refs = 2 if sys.version_info < (3, 11) else 1
    for _ in range(attempts):
        if grc(obj) <= rc + my_refs:
            return True
        time.sleep(interval)
    return False
|
|
|
|
|
|
class TestFrame(BaseZMQTestCase):
    """Tests for zmq.Frame: representations, bytes/buffer access, reference
    counting with zero-copy frames, message tracking, and sending/receiving
    frames over PAIR sockets."""

    def tearDown(self):
        super().tearDown()
        # run the collector a few times so objects released by the test are
        # cleaned up before the next test starts
        for _ in range(3):
            gc.collect()

    @skip_pypy
    def test_above_30(self):
        """Messages above 30 bytes are never copied by 0MQ."""
        for i in range(5, 16):  # 32, 64,..., 65536
            s = (2**i) * x
            rc = grc(s)
            m = zmq.Frame(s, copy=False)
            # building a zero-copy Frame adds two references to s
            assert grc(s) == rc + 2
            del m
            # the references are released asynchronously; see await_gc
            await_gc(s, rc)
            assert grc(s) == rc
            del s

    def test_str(self):
        """Test the str representations of the Frames."""
        m = zmq.Frame(b'')
        assert str(m) == "<zmq.Frame(b'')>"
        m = zmq.Frame(b'123456789')
        assert str(m) == "<zmq.Frame(b'123456789')>"
        m = zmq.Frame(b'x' * 20)
        assert str(m) == "<zmq.Frame(b'xxxxxxxxxxxx'...20B)>"
        m = zmq.Frame(b'x' * 2000)
        assert str(m) == "<zmq.Frame(b'xxxxxxxxxxxx'...2kB)>"
        m = zmq.Frame(b'x' * 2_000_000)
        assert str(m) == "<zmq.Frame(b'xxxxxxxxxxxx'...2MB)>"

    def test_bytes(self):
        """Test the Frame.bytes property."""
        for i in range(1, 16):
            s = (2**i) * x
            m = zmq.Frame(s)
            b = m.bytes
            assert s == m.bytes
            if not PYPY:
                # check that it copies
                assert b is not s
            # check that it copies only once
            assert b is m.bytes

    def test_unicode(self):
        """Test the unicode representations of the Frames."""
        s = 'asdf'
        # str is rejected; only bytes-like input is accepted
        self.assertRaises(TypeError, zmq.Frame, s)
        for i in range(16):
            s = (2**i) * '§'
            m = zmq.Frame(s.encode('utf8'))
            assert s == m.bytes.decode('utf8')

    def test_len(self):
        """Test the len of the Frames."""
        for i in range(16):
            s = (2**i) * x
            m = zmq.Frame(s)
            assert len(s) == len(m)

    @skip_pypy
    def test_lifecycle1(self):
        """Run through a ref counting cycle with a copy."""
        for i in range(5, 16):  # 32, 64,..., 65536
            s = (2**i) * x
            rc = rc_0 = grc(s)
            m = zmq.Frame(s, copy=False)
            rc += 2
            assert grc(s) == rc
            m2 = copy.copy(m)
            rc += 1
            assert grc(s) == rc
            # no increase in refcount for accessing buffer
            # which references m2 directly
            buf = m2.buffer
            assert grc(s) == rc

            assert s == bytes(m2)
            assert s == m.bytes
            assert s == bytes(buf)
            del m2
            assert grc(s) == rc
            # buf holds a direct reference to m2, which holds a reference
            # to s, so the copy's reference is only dropped with buf
            del buf
            rc -= 1
            assert grc(s) == rc
            del m
            rc -= 2
            await_gc(s, rc)
            assert grc(s) == rc
            assert rc == rc_0
            del s

    @skip_pypy
    def test_lifecycle2(self):
        """Run through a different ref counting cycle with a copy."""
        for i in range(5, 16):  # 32, 64,..., 65536
            s = (2**i) * x
            rc = rc_0 = grc(s)
            m = zmq.Frame(s, copy=False)
            rc += 2
            assert grc(s) == rc
            m2 = copy.copy(m)
            rc += 1
            assert grc(s) == rc
            # no increase in refcount for accessing buffer
            # which references m directly
            buf = m.buffer
            assert grc(s) == rc
            assert s == bytes(m2)
            assert s == m2.bytes
            assert s == m.bytes
            assert s == bytes(buf)
            # m is still alive, so dropping its buffer releases nothing
            del buf
            assert grc(s) == rc
            del m
            rc -= 1
            assert grc(s) == rc
            del m2
            rc -= 2
            await_gc(s, rc)
            assert grc(s) == rc
            assert rc == rc_0
            del s

    def test_tracker(self):
        m = zmq.Frame(b'asdf', copy=False, track=True)
        assert not m.tracker.done
        pm = zmq.MessageTracker(m)
        assert not pm.done
        del m
        for _ in range(3):
            gc.collect()
        # poll for up to ~1s for the tracker to report completion
        for _ in range(10):
            if pm.done:
                break
            time.sleep(0.1)
        assert pm.done

    def test_no_tracker(self):
        m = zmq.Frame(b'asdf', track=False)
        assert m.tracker is None
        m2 = copy.copy(m)
        assert m2.tracker is None
        self.assertRaises(ValueError, zmq.MessageTracker, m)

    def test_multi_tracker(self):
        m = zmq.Frame(b'asdf', copy=False, track=True)
        m2 = zmq.Frame(b'whoda', copy=False, track=True)
        mt = zmq.MessageTracker(m, m2)
        assert not m.tracker.done
        assert not mt.done
        self.assertRaises(zmq.NotDone, mt.wait, 0.1)
        del m
        for _ in range(3):
            gc.collect()
        # m2 is still alive, so the combined tracker must not be done yet
        self.assertRaises(zmq.NotDone, mt.wait, 0.1)
        assert not mt.done
        del m2
        for _ in range(3):
            gc.collect()
        assert mt.wait(0.1) is None
        assert mt.done

    def test_buffer_in(self):
        """test using a buffer as input"""
        ins = "§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√".encode()
        zmq.Frame(memoryview(ins))

    def test_bad_buffer_in(self):
        """test using a bad object"""
        self.assertRaises(TypeError, zmq.Frame, 5)
        self.assertRaises(TypeError, zmq.Frame, object())

    def test_buffer_out(self):
        """receiving buffered output"""
        ins = "§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√".encode()
        m = zmq.Frame(ins)
        outb = m.buffer
        assert isinstance(outb, memoryview)
        # the buffer is created once and then reused
        assert outb is m.buffer
        assert m.buffer is m.buffer

    def test_memoryview_shape(self):
        """memoryview shape info"""
        data = "§§¶•ªº˜µ¬˚…∆˙åß∂©œ∑´†≈ç√".encode()
        n = len(data)
        f = zmq.Frame(data)
        view1 = f.buffer
        assert view1.ndim == 1
        assert view1.shape == (n,)
        assert view1.tobytes() == data
        view2 = memoryview(f)
        assert view2.ndim == 1
        assert view2.shape == (n,)
        assert view2.tobytes() == data

    def test_multisend(self):
        """ensure that a message remains intact after multiple sends"""
        a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
        s = b"message"
        m = zmq.Frame(s)
        assert s == m.bytes

        a.send(m, copy=False)
        time.sleep(0.1)
        assert s == m.bytes
        a.send(m, copy=False)
        time.sleep(0.1)
        assert s == m.bytes
        a.send(m, copy=True)
        time.sleep(0.1)
        assert s == m.bytes
        a.send(m, copy=True)
        time.sleep(0.1)
        assert s == m.bytes
        for _ in range(4):
            r = b.recv()
            assert s == r
        assert s == m.bytes

    def test_memoryview(self):
        """test messages from memoryview"""
        s = b'carrotjuice'
        # build the Frame from a memoryview rather than the bytes object,
        # which is what this test is meant to cover
        m = zmq.Frame(memoryview(s))
        buf = m.buffer
        s2 = buf.tobytes()
        assert s2 == s
        assert m.bytes == s

    def test_noncopying_recv(self):
        """check for clobbering message buffers"""
        null = b'\0' * 64
        sa, sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
        for _ in range(32):
            # try a few times
            sb.send(null, copy=False)
            m = sa.recv(copy=False)
            mb = m.bytes
            buf = m.buffer
            del m
            # later messages must not overwrite the memory behind buf
            for i in range(5):
                ff = b'\xff' * (40 + i * 10)
                sb.send(ff, copy=False)
                m2 = sa.recv(copy=False)
                b = buf.tobytes()
                assert b == null
                assert mb == null
                assert m2.bytes == ff
                assert type(m2.bytes) is bytes

    def test_noncopying_memoryview(self):
        """test non-copying memoryview messages"""
        null = b'\0' * 64
        sa, sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
        for _ in range(32):
            # try a few times
            sb.send(memoryview(null), copy=False)
            m = sa.recv(copy=False)
            buf = memoryview(m)
            # later messages must not overwrite the memory behind buf
            for i in range(5):
                ff = b'\xff' * (40 + i * 10)
                sb.send(memoryview(ff), copy=False)
                m2 = sa.recv(copy=False)
                buf2 = memoryview(m2)
                assert buf.tobytes() == null
                assert not buf.readonly
                assert buf2.tobytes() == ff
                assert not buf2.readonly
                assert type(buf) is memoryview

    def test_buffer_numpy(self):
        """test non-copying numpy array messages"""
        try:
            import numpy
            from numpy.testing import assert_array_equal
        except ImportError:
            raise SkipTest("requires numpy")
        rand = numpy.random.randint
        shapes = [rand(2, 5) for _ in range(5)]
        a, b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
        dtypes = [int, float, '>i4', 'B']
        for i in range(1, len(shapes) + 1):
            shape = shapes[:i]
            for dt in dtypes:
                A = numpy.empty(shape, dtype=dt)
                a.send(A, copy=False)
                msg = b.recv(copy=False)

                B = numpy.frombuffer(msg, A.dtype).reshape(A.shape)
                assert_array_equal(A, B)

            # structured dtype; 'S32' is fixed-width bytes (the older 'a32'
            # spelling is a deprecated alias of the same dtype in NumPy 2.0)
            A = numpy.empty(shape, dtype=[('a', int), ('b', float), ('c', 'S32')])
            A['a'] = 1024
            A['b'] = 1e9
            A['c'] = 'hello there'
            a.send(A, copy=False)
            msg = b.recv(copy=False)

            B = numpy.frombuffer(msg, A.dtype).reshape(A.shape)
            assert_array_equal(A, B)

    @skip_pypy
    def test_frame_more(self):
        """test Frame.more attribute"""
        frame = zmq.Frame(b"hello")
        assert not frame.more
        sa, sb = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
        sa.send_multipart([b'hi', b'there'])
        # first part of a two-part message: more parts follow
        frame = self.recv(sb, copy=False)
        assert frame.more
        if zmq.zmq_version_info()[0] >= 3 and not PYPY:
            assert frame.get(zmq.MORE)
        # last part: no more parts
        frame = self.recv(sb, copy=False)
        assert not frame.more
        if zmq.zmq_version_info()[0] >= 3 and not PYPY:
            assert not frame.get(zmq.MORE)
|