diff --git a/messages.py b/messages.py index ddea44f..7d8c140 100644 --- a/messages.py +++ b/messages.py @@ -18,12 +18,20 @@ class InvalidNonceError(Exception): def make_envelope(msgtype, msg, nodeid): msg['nodeid'] = nodeid msg['nonce'] = nonce() - sign = hmac.new(nodeid, json.dumps(msg)) + data = json.dumps(msg) + sign = hmac.new(nodeid, data) envelope = {'data': msg, 'sign': sign.hexdigest(), 'msgtype': msgtype} + #print "make_envelope:", envelope return json.dumps(envelope) +def envelope_decorator(nodeid, func): + msgtype = func.__name__.split("_")[0] + def inner(*args, **kwargs): + return make_envelope(msgtype, func(*args, **kwargs), nodeid) + return inner + # ------ def create_ackhello(nodeid): @@ -42,6 +50,13 @@ def create_pong(nodeid): msg = {} return make_envelope("pong", msg, nodeid) +def create_getaddr(nodeid): + msg = {} + return make_envelope("getaddr", msg, nodeid) + +def create_addr(nodeid, nodes): + msg = {'nodes': nodes} + return make_envelope("addr", msg, nodeid) # ------- def read_envelope(message): @@ -56,6 +71,7 @@ def read_message(message): signature = str(envelope['sign']) msg = json.dumps(envelope['data']) verify_sign = hmac.new(nodeid, msg) + #print "read_message:", msg if hmac.compare_digest(verify_sign.hexdigest(), signature): return envelope['data'] else: diff --git a/network.py b/network.py index bf89348..ad52f81 100644 --- a/network.py +++ b/network.py @@ -1,6 +1,7 @@ import sys from datetime import datetime from time import time +from functools import partial from twisted.internet import reactor from twisted.internet.protocol import Protocol, Factory @@ -12,7 +13,7 @@ from twisted.internet.task import LoopingCall import messages import cryptotools -PING_INTERVAL = 20.0 +PING_INTERVAL = 1200.0 # 20 min = 1200.0 BOOTSTRAP_NODES = ["localhost:5008", "localhost:5007", "localhost:5006", @@ -32,6 +33,7 @@ class NCProtocol(Protocol): self.kind = kind self.nodeid = self.factory.nodeid self.lc_ping = LoopingCall(self.send_PING) + self.message = partial(messages.envelope_decorator, self.nodeid) def connectionMade(self): r_ip = self.transport.getPeer() @@ -56,6 +58,7 @@ class NCProtocol(Protocol): # since ping keeps going if we don't .stop() it. try: self.lc_ping.stop() except AssertionError: pass + try: self.factory.peers.pop(self.remote_nodeid) if self.nodeid != self.remote_nodeid: @@ -79,6 +82,8 @@ class NCProtocol(Protocol): self.handle_PING(line) elif envelope['msgtype'] == 'pong': self.handle_PONG(line) + elif envelope['msgtype'] == 'addr': + self.handle_ADDR(line) def send_PING(self): _print(" [>] PING to", self.remote_nodeid, "at", self.remote_ip) @@ -90,6 +95,42 @@ class NCProtocol(Protocol): pong = messages.create_pong(self.nodeid) self.write(pong) + def send_ADDR(self): + _print(" [>] Telling " + self.remote_nodeid + " about my peers") + # Shouldn't this be a list and not a dict? + peers = self.factory.peers + listeners = [(n, peers[n][0], peers[n][1], peers[n][2]) + for n in peers] + addr = messages.create_addr(self.nodeid, listeners) + self.write(addr) + + def handle_ADDR(self, addr): + try: + nodes = messages.read_message(addr)['nodes'] + _print(" [<] Recieved addr list from peer " + self.remote_nodeid) + #for node in filter(lambda n: nodes[n][1] == "SEND", nodes): + for node in nodes: + _print(" [*] " + node[0] + " " + node[1]) + if node[0] == self.nodeid: + _print(" [!] Not connecting to " + node[0] + ": thats me!") + return + if node[1] != "SEND": + _print(" [ ] Not connecting to " + node[0] + ": is " + node[1]) + return + if node[0] in self.factory.peers: + _print(" [ ] Not connecting to " + node[0] + ": already connected") + return + _print(" [ ] Trying to connect to peer " + node[0] + " " + node[1]) + # TODO: Use [2] and a time limit to not connect to "old" peers + host, port = node[0].split(":") + point = TCP4ClientEndpoint(reactor, host, int(port)) + d = connectProtocol(point, NCProtocol(ncfactory, "SENDHELLO", "SEND")) + d.addCallback(gotProtocol) + except messages.InvalidSignatureError: + print addr + _print(" [!] ERROR: Invalid addr sign ", self.remote_ip) + self.transport.loseConnection() + def handle_PONG(self, pong): pong = messages.read_message(pong) _print(" [<] PONG from", self.remote_nodeid, "at", self.remote_ip) @@ -114,18 +155,24 @@ class NCProtocol(Protocol): if self.state == "GETHELLO": my_hello = messages.create_hello(self.nodeid, self.VERSION) self.transport.write(my_hello + "\n") - entry = (self.remote_ip, self.kind, time()) - self.factory.peers[self.remote_nodeid] = entry + self.add_peer() self.state = "READY" self.print_peers() - self.write(messages.create_ping(self.nodeid)) + #self.write(messages.create_ping(self.nodeid)) if self.kind == "RECV": # The listener pings it's audience + _print(" [ ] Starting pinger to " + self.remote_nodeid) self.lc_ping.start(PING_INTERVAL, now=False) + # Tell new audience about my peers + self.send_ADDR() except messages.InvalidSignatureError: _print(" [!] ERROR: Invalid hello sign ", self.remoteip) self.transport.loseConnection() + def add_peer(self): + entry = (self.remote_ip, self.kind, time()) + self.factory.peers[self.remote_nodeid] = entry + # Splitinto NCRecvFactory and NCSendFactory (also reconsider the names...:/) class NCFactory(Factory): def __init__(self):